Codebase cleanup

Removed a lot of legacy stuff I have no intention of using again.
Plan is to shape this repo into something more extensible (get it? hah!)
This commit is contained in:
James Betker 2020-10-13 20:56:39 -06:00
parent e620fc05ba
commit 24792bdb4f
61 changed files with 17 additions and 3128 deletions

View File

@ -1,4 +0,0 @@
[style]
BASED_ON_STYLE = pep8
COLUMN_LIMIT = 100
SPLIT_BEFORE_NAMED_ASSIGNS = false

View File

@ -1,124 +0,0 @@
import random
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image
from io import BytesIO
import torchvision.transforms.functional as F
class DownsampleDataset(data.Dataset):
"""
Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a
downsampled LQ image from the HQ image and feeds that as well.
"""
def __init__(self, opt):
super(DownsampleDataset, self).__init__()
self.opt = opt
self.data_type = self.opt['data_type']
self.paths_LQ, self.paths_GT = None, None
self.sizes_LQ, self.sizes_GT = None, None
self.LQ_env, self.GT_env = None, None # environments for lmdb
self.doCrop = self.opt['doCrop']
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
self.data_sz_mismatch_ok = opt['mismatched_Data_OK']
assert self.paths_GT, 'Error: GT path is empty.'
assert self.paths_LQ, 'LQ is required for downsampling.'
if not self.data_sz_mismatch_ok:
assert len(self.paths_LQ) == len(
self.paths_GT
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
len(self.paths_LQ), len(self.paths_GT))
self.random_scale_list = [1]
def _init_lmdb(self):
# https://github.com/chainer/chainermn/issues/129
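# lmdb environments cannot be shared across DataLoader worker processes, so they are opened
# lazily on first access rather than in __init__.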
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
def __getitem__(self, index):
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
self._init_lmdb()
scale = self.opt['scale']
GT_size = self.opt['target_size'] * scale
# get GT image
GT_path = self.paths_GT[index % len(self.paths_GT)]
resolution = [int(s) for s in self.sizes_GT[index].split('_')
] if self.data_type == 'lmdb' else None
img_GT = util.read_img(self.GT_env, GT_path, resolution)
if self.opt['phase'] != 'train': # modcrop in the validation / test phase
img_GT = util.modcrop(img_GT, scale)
if self.opt['color']: # change color space if necessary
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
# get LQ image
LQ_path = self.paths_LQ[index % len(self.paths_LQ)]
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
] if self.data_type == 'lmdb' else None
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
if self.opt['phase'] == 'train':
H, W, _ = img_GT.shape
assert H >= GT_size and W >= GT_size
H, W, C = img_LQ.shape
LQ_size = GT_size // scale
if self.doCrop:
# randomly crop
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
else:
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
# augmentation - flip, rotate
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
self.opt['use_rot'])
# BGR to RGB, HWC to CHW, numpy to tensor
if img_GT.shape[2] == 3:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
# HQ needs to go to a PIL image to perform the compression-artifact transformation.
H, W, _ = img_GT.shape
img_GT = (img_GT * 255).astype(np.uint8)
img_GT = Image.fromarray(img_GT)
if self.opt['use_compression_artifacts']:
qf = random.randrange(15, 100)
corruption_buffer = BytesIO()
img_GT.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
corruption_buffer.seek(0)
img_GT = Image.open(corruption_buffer)
# Generate a downsampled image from HQ for feature and PIX losses.
img_Downsampled = F.resize(img_GT, (int(H / scale), int(W / scale)))
img_GT = F.to_tensor(img_GT)
img_Downsampled = F.to_tensor(img_Downsampled)
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
# This may seem really messed up, but let me explain:
# The goal is to re-use existing code as much as possible. SRGAN_model was coded to supersample, not downsample,
# but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the HQ image and the
# "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator which already know this,
# we just need to trick its logic into following this rules.
# Do this by setting LQ(which is the input into the models)=img_GT and GT(which is the expected outpuut)=img_LQ.
# PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}
def __len__(self):
return max(len(self.paths_GT), len(self.paths_LQ))
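# Illustrative sketch only (not part of the original file): how a trainer retrofitted this way
# would read the swapped keys. Variable names here are hypothetical.
#   batch = dataset[0]
#   net_input = batch['LQ']    # actually the HQ image
#   net_target = batch['GT']   # actually the real LQ image
#   pixel_ref = batch['PIX']   # HQ downsampled manually; used as the pixel-loss reference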

View File

@ -1,239 +0,0 @@
import random
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image, ImageOps
from io import BytesIO
import torchvision.transforms.functional as F
class LQGTDataset(data.Dataset):
"""
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
If only GT images are provided, generate LQ images on-the-fly.
"""
def get_lq_path(self, i):
which_lq = random.randint(0, len(self.paths_LQ)-1)
return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]
def __init__(self, opt):
super(LQGTDataset, self).__init__()
self.opt = opt
self.data_type = self.opt['data_type']
self.paths_LQ, self.paths_GT = None, None
self.sizes_LQ, self.sizes_GT = None, None
self.paths_PIX, self.sizes_PIX = None, None
self.paths_GAN, self.sizes_GAN = None, None
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
if 'dataroot_LQ' in opt.keys():
self.paths_LQ = []
if isinstance(opt['dataroot_LQ'], list):
# Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
# we want the model to learn them all.
for dr_lq in opt['dataroot_LQ']:
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
self.paths_LQ.append(lq_path)
else:
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
self.paths_LQ.append(lq_path)
self.doCrop = opt['doCrop']
if 'dataroot_PIX' in opt.keys():
self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
# dataroot_GAN is an alternative source of LR images specifically for use in computing the GAN loss, where
# LR and HR do not need to be paired.
if 'dataroot_GAN' in opt.keys():
self.paths_GAN, self.sizes_GAN = util.get_image_paths(self.data_type, opt['dataroot_GAN'])
print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,))
assert self.paths_GT, 'Error: GT path is empty.'
self.random_scale_list = [1]
def _init_lmdb(self):
# https://github.com/chainer/chainermn/issues/129
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
if 'dataroot_PIX' in self.opt.keys():
self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False,
meminit=False)
def motion_blur(self, image, size, angle):
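# Build a (size x size) kernel containing a single horizontal line, rotate it to `angle` with
# warpAffine, normalize it so overall brightness is preserved, then convolve it with the image.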
k = np.zeros((size, size), dtype=np.float32)
k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
k = k * (1.0 / np.sum(k))
return cv2.filter2D(image, -1, k)
def __getitem__(self, index):
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
self._init_lmdb()
GT_path, LQ_path = None, None
scale = self.opt['scale']
GT_size = self.opt['target_size']
# get GT image
GT_path = self.paths_GT[index % len(self.paths_GT)]
resolution = [int(s) for s in self.sizes_GT[index].split('_')
] if self.data_type == 'lmdb' else None
img_GT = util.read_img(self.GT_env, GT_path, resolution)
if self.opt['phase'] != 'train': # modcrop in the validation / test phase
img_GT = util.modcrop(img_GT, scale)
if self.opt['color']: # change color space if necessary
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
# get the pix image
if self.paths_PIX is not None:
PIX_path = self.paths_PIX[index % len(self.paths_PIX)]
img_PIX = util.read_img(self.PIX_env, PIX_path, resolution)
if self.opt['color']: # change color space if necessary
img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0]
else:
img_PIX = img_GT
# get LQ image
if self.paths_LQ:
LQ_path = self.get_lq_path(index)
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
] if self.data_type == 'lmdb' else None
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
else: # down-sampling on-the-fly
# randomly scale during training
if self.opt['phase'] == 'train':
random_scale = random.choice(self.random_scale_list)
H_s, W_s, _ = img_GT.shape
def _mod(n, random_scale, scale, thres):
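# Scale a dimension by random_scale, snap it down to a multiple of `scale`, and enforce a floor of `thres`.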
rlt = int(n * random_scale)
rlt = (rlt // scale) * scale
return thres if rlt < thres else rlt
H_s = _mod(H_s, random_scale, scale, GT_size)
W_s = _mod(W_s, random_scale, scale, GT_size)
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
if img_GT.ndim == 2:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
H, W, _ = img_GT.shape
# using matlab imresize
if scale == 1:
img_LQ = img_GT
else:
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)
img_GAN = None
if self.paths_GAN:
GAN_path = self.paths_GAN[index % self.sizes_GAN]
img_GAN = util.read_img(self.LQ_env, GAN_path)
# Enforce force_multiple constraints.
h, w, _ = img_LQ.shape
if h % self.force_multiple != 0 or w % self.force_multiple != 0:
h, w = (h - h % self.force_multiple), (w - w % self.force_multiple)
img_LQ = img_LQ[:h, :w, :]
h *= scale
w *= scale
img_GT = img_GT[:h, :w, :]
img_PIX = img_PIX[:h, :w, :]
if self.opt['phase'] == 'train':
H, W, _ = img_GT.shape
assert H >= GT_size and W >= GT_size
H, W, C = img_LQ.shape
LQ_size = GT_size // scale
if self.doCrop:
# randomly crop
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
if img_GAN is not None:
img_GAN = img_GAN[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
else:
if img_LQ.shape[0] != LQ_size:
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
if img_GAN is not None:
img_GAN = cv2.resize(img_GAN, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
if img_GT.shape[0] != GT_size:
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
if img_PIX.shape[0] != GT_size:
img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
if 'doResizeLoss' in self.opt.keys() and self.opt['doResizeLoss']:
r = random.randrange(0, 10)
if r > 5:
img_LQ = cv2.resize(img_LQ, (int(LQ_size/2), int(LQ_size/2)), interpolation=cv2.INTER_LINEAR)
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
# augmentation - flip, rotate
img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'],
self.opt['use_rot'])
if self.opt['use_blurring']:
# Pick randomly between gaussian, motion, or no blur.
blur_det = random.randint(0, 100)
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
if blur_det < 40:
blur_sig = int(random.randrange(0, blur_magnitude))
img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
elif blur_det < 70:
img_LQ = self.motion_blur(img_LQ, random.randrange(1, blur_magnitude * 3), random.randint(0, 360))
if self.opt['color']: # change color space if necessary
img_LQ = util.channel_convert(C, self.opt['color'],
[img_LQ])[0] # TODO during val no definition
# BGR to RGB, HWC to CHW, numpy to tensor
if img_GT.shape[2] == 3:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
if img_GAN is not None:
img_GAN = cv2.cvtColor(img_GAN, cv2.COLOR_BGR2RGB)
img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB)
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
img_LQ = (img_LQ * 255).astype(np.uint8)
img_LQ = Image.fromarray(img_LQ)
if self.opt['use_compression_artifacts'] and random.random() > .25:
qf = random.randrange(10, 70)
corruption_buffer = BytesIO()
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
corruption_buffer.seek(0)
img_LQ = Image.open(corruption_buffer)
if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
img_LQ = F.to_tensor(img_LQ)
if img_GAN is not None:
img_GAN = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GAN, (2, 0, 1)))).float()
if 'lq_noise' in self.opt.keys():
lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
img_LQ += lq_noise
if LQ_path is None:
LQ_path = GT_path
d = {'LQ': img_LQ, 'GT': img_GT, 'ref': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
if img_GAN is not None:
d['GAN'] = img_GAN
return d
def __len__(self):
return len(self.paths_GT)

View File

@ -1,71 +0,0 @@
import numpy as np
import lmdb
import torch
import torch.utils.data as data
import data.util as util
import torchvision.transforms.functional as F
from PIL import Image
import os.path as osp
import cv2
class LQDataset(data.Dataset):
'''Read LQ images only in the test phase.'''
def __init__(self, opt):
super(LQDataset, self).__init__()
self.opt = opt
self.data_type = self.opt['data_type']
if 'start_at' in self.opt.keys():
self.start_at = self.opt['start_at']
else:
self.start_at = 0
self.vertical_splits = self.opt['vertical_splits']
self.paths_LQ, self.paths_GT = None, None
self.LQ_env = None # environment for lmdb
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
self.paths_LQ = self.paths_LQ[self.start_at:]
assert self.paths_LQ, 'Error: LQ paths are empty.'
def _init_lmdb(self):
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
def __getitem__(self, index):
if self.data_type == 'lmdb' and self.LQ_env is None:
self._init_lmdb()
if self.vertical_splits > 0:
actual_index = int(index / self.vertical_splits)
else:
actual_index = index
# get LQ image
LQ_path = self.paths_LQ[actual_index]
img_LQ = Image.open(LQ_path)
if self.vertical_splits > 0:
w, h = img_LQ.size
split_index = (index % self.vertical_splits)
w_per_split = int(w / self.vertical_splits)
left = w_per_split * split_index
img_LQ = F.crop(img_LQ, 0, left, h, w_per_split)
# Enforce force_multiple constraints.
h, w = img_LQ.size
if h % self.force_multiple != 0 or w % self.force_multiple != 0:
h, w = (w - w % self.force_multiple), (h - h % self.force_multiple)
img_LQ = img_LQ.resize((w, h))
img_LQ = F.to_tensor(img_LQ)
img_name = osp.splitext(osp.basename(LQ_path))[0]
LQ_path = LQ_path.replace(img_name, img_name + "_%i" % (index % self.vertical_splits))
return {'LQ': img_LQ, 'LQ_path': LQ_path}
def __len__(self):
if self.vertical_splits > 0:
return len(self.paths_LQ) * self.vertical_splits
else:
return len(self.paths_LQ)

View File

@ -29,14 +29,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
def create_dataset(dataset_opt):
mode = dataset_opt['mode']
# datasets for image restoration
if mode == 'LQ':
from data.LQ_dataset import LQDataset as D
elif mode == 'LQGT':
from data.LQGT_dataset import LQGTDataset as D
# datasets for image corruption
elif mode == 'downsample':
from data.Downsample_dataset import DownsampleDataset as D
elif mode == 'fullimage':
if mode == 'fullimage':
from data.full_image_dataset import FullImageDataset as D
elif mode == 'single_image_extensible':
from data.single_image_dataset import SingleImageDataset as D

File diff suppressed because it is too large Load Diff

View File

@ -1,171 +0,0 @@
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.loss import CharbonnierLoss
from apex import amp
logger = logging.getLogger('base')
class SRModel(BaseModel):
def __init__(self, opt):
super(SRModel, self).__init__(opt)
if opt['dist']:
self.rank = torch.distributed.get_rank()
else:
self.rank = -1 # non dist training
train_opt = opt['train']
# define network and load pretrained models
self.netG = amp.initialize(networks.define_G(opt).to(self.device), opt_level=self.amp_level)
if opt['dist']:
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
elif opt['gpu_ids'] is not None:
self.netG = DataParallel(self.netG)
# print network
self.print_network()
self.load()
if self.is_train:
self.netG.train()
# loss
loss_type = train_opt['pixel_criterion']
if loss_type == 'l1':
self.cri_pix = nn.L1Loss().to(self.device)
elif loss_type == 'l2':
self.cri_pix = nn.MSELoss().to(self.device)
elif loss_type == 'cb':
self.cri_pix = CharbonnierLoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
self.l_pix_w = train_opt['pixel_weight']
# optimizers
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
optim_params = []
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
optim_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
weight_decay=wd_G,
betas=(train_opt['beta1'], train_opt['beta2']))
self.optimizers.append(self.optimizer_G)
# schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
restarts=train_opt['restarts'],
weights=train_opt['restart_weights'],
gamma=train_opt['lr_gamma'],
clear_state=train_opt['clear_state']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.CosineAnnealingLR_Restart(
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
else:
raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
self.log_dict = OrderedDict()
def feed_data(self, data, need_GT=True):
self.var_L = data['LQ'].to(self.device) # LQ
if need_GT:
self.real_H = data['GT'].to(self.device) # GT
def optimize_parameters(self, step):
self.optimizer_G.zero_grad()
self.fake_H = self.netG(self.var_L)
l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
l_pix.backward()
self.optimizer_G.step()
# set log
self.log_dict['l_pix'] = l_pix.item()
def test(self):
self.netG.eval()
with torch.no_grad():
self.fake_H = self.netG(self.var_L)
self.netG.train()
def test_x8(self):
# from https://github.com/thstkdgus35/EDSR-PyTorch
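# Self-ensemble inference: run the network on all eight flip/transpose augmentations of the
# input, undo each augmentation on the corresponding output, and average the results.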
self.netG.eval()
def _transform(v, op):
# if self.precision != 'single': v = v.float()
v2np = v.data.cpu().numpy()
if op == 'v':
tfnp = v2np[:, :, :, ::-1].copy()
elif op == 'h':
tfnp = v2np[:, :, ::-1, :].copy()
elif op == 't':
tfnp = v2np.transpose((0, 1, 3, 2)).copy()
ret = torch.Tensor(tfnp).to(self.device)
# if self.precision == 'half': ret = ret.half()
return ret
lr_list = [self.var_L]
for tf in 'v', 'h', 't':
lr_list.extend([_transform(t, tf) for t in lr_list])
with torch.no_grad():
sr_list = [self.netG(aug) for aug in lr_list]
for i in range(len(sr_list)):
if i > 3:
sr_list[i] = _transform(sr_list[i], 't')
if i % 4 > 1:
sr_list[i] = _transform(sr_list[i], 'h')
if (i % 4) % 2 == 1:
sr_list[i] = _transform(sr_list[i], 'v')
output_cat = torch.cat(sr_list, dim=0)
self.fake_H = output_cat.mean(dim=0, keepdim=True)
self.netG.train()
def get_current_log(self):
return self.log_dict
def get_current_visuals(self, need_GT=True):
out_dict = OrderedDict()
out_dict['LQ'] = self.var_L.detach().float().cpu()
out_dict['rlt'] = self.fake_H.detach().float().cpu()
if need_GT:
out_dict['GT'] = self.real_H.detach().float().cpu()
return out_dict
def print_network(self):
s, n = self.get_network_description(self.netG)
if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
self.netG.module.__class__.__name__)
else:
net_struc_str = '{}'.format(self.netG.__class__.__name__)
if self.rank <= 0:
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
logger.info(s)
def load(self):
load_path_G = self.opt['path']['pretrain_model_G']
if load_path_G is not None:
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
def save(self, iter_label):
self.save_network(self.netG, 'G', iter_label)

View File

@ -1,22 +0,0 @@
import logging
logger = logging.getLogger('base')
def create_model(opt):
model = opt['model']
# image restoration
if model == 'sr': # PSNR-oriented super resolution
from .SR_model import SRModel as M
elif model == 'srgan' or model == 'corruptgan' or model == 'spsrgan':
from .SRGAN_model import SRGANModel as M
elif model == 'feat':
from .feature_model import FeatureModel as M
elif model == 'spsr':
from .SPSR_model import SPSRModel as M
elif model == 'extensibletrainer':
from .ExtensibleTrainer import ExtensibleTrainer as M
else:
raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
m = M(opt)
logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
return m

View File

@ -1,39 +0,0 @@
from torch.autograd import Function, Variable
from torch.nn.modules.module import Module
import channelnorm_cuda
class ChannelNormFunction(Function):
@staticmethod
def forward(ctx, input1, norm_deg=2):
assert input1.is_contiguous()
b, _, h, w = input1.size()
output = input1.new(b, 1, h, w).zero_()
channelnorm_cuda.forward(input1, output, norm_deg)
ctx.save_for_backward(input1, output)
ctx.norm_deg = norm_deg
return output
@staticmethod
def backward(ctx, grad_output):
input1, output = ctx.saved_tensors
grad_input1 = Variable(input1.new(input1.size()).zero_())
channelnorm_cuda.backward(input1, output, grad_output.data,
grad_input1.data, ctx.norm_deg)
return grad_input1, None
class ChannelNorm(Module):
def __init__(self, norm_deg=2):
super(ChannelNorm, self).__init__()
self.norm_deg = norm_deg
def forward(self, input1):
return ChannelNormFunction.apply(input1, self.norm_deg)
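# Illustrative usage sketch (not part of the original file); assumes the channelnorm_cuda
# extension has been built and a CUDA device is available:
#   norm = ChannelNorm(norm_deg=2)
#   x = torch.randn(4, 8, 32, 32, device='cuda').contiguous()
#   y = norm(x)  # norm taken over the channel dimension -> shape (4, 1, 32, 32)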

View File

@ -1,31 +0,0 @@
#include <torch/torch.h>
#include <ATen/ATen.h>
#include "channelnorm_kernel.cuh"
int channelnorm_cuda_forward(
at::Tensor& input1,
at::Tensor& output,
int norm_deg) {
channelnorm_kernel_forward(input1, output, norm_deg);
return 1;
}
int channelnorm_cuda_backward(
at::Tensor& input1,
at::Tensor& output,
at::Tensor& gradOutput,
at::Tensor& gradInput1,
int norm_deg) {
channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg);
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)");
m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)");
}

View File

@ -1,16 +0,0 @@
#pragma once
#include <ATen/ATen.h>
void channelnorm_kernel_forward(
at::Tensor& input1,
at::Tensor& output,
int norm_deg);
void channelnorm_kernel_backward(
at::Tensor& input1,
at::Tensor& output,
at::Tensor& gradOutput,
at::Tensor& gradInput1,
int norm_deg);

View File

@ -1,28 +0,0 @@
#!/usr/bin/env python3
import os
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
cxx_args = ['-std=c++11']
nvcc_args = [
'-gencode', 'arch=compute_52,code=sm_52',
'-gencode', 'arch=compute_60,code=sm_60',
'-gencode', 'arch=compute_61,code=sm_61',
'-gencode', 'arch=compute_70,code=sm_70',
'-gencode', 'arch=compute_70,code=compute_70'
]
setup(
name='channelnorm_cuda',
ext_modules=[
CUDAExtension('channelnorm_cuda', [
'channelnorm_cuda.cc',
'channelnorm_kernel.cu'
], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
],
cmdclass={
'build_ext': BuildExtension
})
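# The extension is typically built with `python setup.py install` (or `pip install .`),
# after which `import channelnorm_cuda` succeeds.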

View File

@ -1,61 +0,0 @@
import torch
from torch.nn.modules.module import Module
from torch.autograd import Function
import correlation_cuda
class CorrelationFunction(Function):
@staticmethod
def forward(ctx, input1, input2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
ctx.save_for_backward(input1, input2)
ctx.pad_size = pad_size
ctx.kernel_size = kernel_size
ctx.max_displacement = max_displacement
ctx.stride1 = stride1
ctx.stride2 = stride2
ctx.corr_multiply = corr_multiply
with torch.cuda.device_of(input1):
rbot1 = input1.new()
rbot2 = input2.new()
output = input1.new()
correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)
return output
@staticmethod
def backward(ctx, grad_output):
input1, input2 = ctx.saved_tensors
with torch.cuda.device_of(input1):
rbot1 = input1.new()
rbot2 = input2.new()
grad_input1 = input1.new()
grad_input2 = input2.new()
correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)
return grad_input1, grad_input2, None, None, None, None, None, None
class Correlation(Module):
def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
super(Correlation, self).__init__()
self.pad_size = pad_size
self.kernel_size = kernel_size
self.max_displacement = max_displacement
self.stride1 = stride1
self.stride2 = stride2
self.corr_multiply = corr_multiply
def forward(self, input1, input2):
result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)
return result
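# Illustrative usage sketch (not part of the original file); the values mirror common
# FlowNetC settings and are assumptions, not taken from this repository:
#   corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1)
#   out = corr(feat1, feat2)  # feat1, feat2: (B, C, H, W) CUDA tensors;
#                             # out has ((max_displacement // stride2) * 2 + 1) ** 2 channels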

View File

@ -1,173 +0,0 @@
#include <torch/torch.h>
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <stdio.h>
#include <iostream>
#include "correlation_cuda_kernel.cuh"
int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
int pad_size,
int kernel_size,
int max_displacement,
int stride1,
int stride2,
int corr_type_multiply)
{
int batchSize = input1.size(0);
int nInputChannels = input1.size(1);
int inputHeight = input1.size(2);
int inputWidth = input1.size(3);
int kernel_radius = (kernel_size - 1) / 2;
int border_radius = kernel_radius + max_displacement;
int paddedInputHeight = inputHeight + 2 * pad_size;
int paddedInputWidth = inputWidth + 2 * pad_size;
int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});
rInput1.fill_(0);
rInput2.fill_(0);
output.fill_(0);
int success = correlation_forward_cuda_kernel(
output,
output.size(0),
output.size(1),
output.size(2),
output.size(3),
output.stride(0),
output.stride(1),
output.stride(2),
output.stride(3),
input1,
input1.size(1),
input1.size(2),
input1.size(3),
input1.stride(0),
input1.stride(1),
input1.stride(2),
input1.stride(3),
input2,
input2.size(1),
input2.stride(0),
input2.stride(1),
input2.stride(2),
input2.stride(3),
rInput1,
rInput2,
pad_size,
kernel_size,
max_displacement,
stride1,
stride2,
corr_type_multiply,
at::cuda::getCurrentCUDAStream()
//at::globalContext().getCurrentCUDAStream()
);
//check for errors
if (!success) {
AT_ERROR("CUDA call failed");
}
return 1;
}
int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
at::Tensor& gradInput1, at::Tensor& gradInput2,
int pad_size,
int kernel_size,
int max_displacement,
int stride1,
int stride2,
int corr_type_multiply)
{
int batchSize = input1.size(0);
int nInputChannels = input1.size(1);
int paddedInputHeight = input1.size(2)+ 2 * pad_size;
int paddedInputWidth = input1.size(3)+ 2 * pad_size;
int height = input1.size(2);
int width = input1.size(3);
rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
gradInput1.resize_({batchSize, nInputChannels, height, width});
gradInput2.resize_({batchSize, nInputChannels, height, width});
rInput1.fill_(0);
rInput2.fill_(0);
gradInput1.fill_(0);
gradInput2.fill_(0);
int success = correlation_backward_cuda_kernel(gradOutput,
gradOutput.size(0),
gradOutput.size(1),
gradOutput.size(2),
gradOutput.size(3),
gradOutput.stride(0),
gradOutput.stride(1),
gradOutput.stride(2),
gradOutput.stride(3),
input1,
input1.size(1),
input1.size(2),
input1.size(3),
input1.stride(0),
input1.stride(1),
input1.stride(2),
input1.stride(3),
input2,
input2.stride(0),
input2.stride(1),
input2.stride(2),
input2.stride(3),
gradInput1,
gradInput1.stride(0),
gradInput1.stride(1),
gradInput1.stride(2),
gradInput1.stride(3),
gradInput2,
gradInput2.size(1),
gradInput2.stride(0),
gradInput2.stride(1),
gradInput2.stride(2),
gradInput2.stride(3),
rInput1,
rInput2,
pad_size,
kernel_size,
max_displacement,
stride1,
stride2,
corr_type_multiply,
at::cuda::getCurrentCUDAStream()
//at::globalContext().getCurrentCUDAStream()
);
if (!success) {
AT_ERROR("CUDA call failed");
}
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
}

View File

@ -1,91 +0,0 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <cuda_runtime.h>
int correlation_forward_cuda_kernel(at::Tensor& output,
int ob,
int oc,
int oh,
int ow,
int osb,
int osc,
int osh,
int osw,
at::Tensor& input1,
int ic,
int ih,
int iw,
int isb,
int isc,
int ish,
int isw,
at::Tensor& input2,
int gc,
int gsb,
int gsc,
int gsh,
int gsw,
at::Tensor& rInput1,
at::Tensor& rInput2,
int pad_size,
int kernel_size,
int max_displacement,
int stride1,
int stride2,
int corr_type_multiply,
cudaStream_t stream);
int correlation_backward_cuda_kernel(
at::Tensor& gradOutput,
int gob,
int goc,
int goh,
int gow,
int gosb,
int gosc,
int gosh,
int gosw,
at::Tensor& input1,
int ic,
int ih,
int iw,
int isb,
int isc,
int ish,
int isw,
at::Tensor& input2,
int gsb,
int gsc,
int gsh,
int gsw,
at::Tensor& gradInput1,
int gisb,
int gisc,
int gish,
int gisw,
at::Tensor& gradInput2,
int ggc,
int ggsb,
int ggsc,
int ggsh,
int ggsw,
at::Tensor& rInput1,
at::Tensor& rInput2,
int pad_size,
int kernel_size,
int max_displacement,
int stride1,
int stride2,
int corr_type_multiply,
cudaStream_t stream);

View File

@ -1,29 +0,0 @@
#!/usr/bin/env python3
import os
import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
cxx_args = ['-std=c++11']
nvcc_args = [
'-gencode', 'arch=compute_50,code=sm_50',
'-gencode', 'arch=compute_52,code=sm_52',
'-gencode', 'arch=compute_60,code=sm_60',
'-gencode', 'arch=compute_61,code=sm_61',
'-gencode', 'arch=compute_70,code=sm_70',
'-gencode', 'arch=compute_70,code=compute_70'
]
setup(
name='correlation_cuda',
ext_modules=[
CUDAExtension('correlation_cuda', [
'correlation_cuda.cc',
'correlation_cuda_kernel.cu'
], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
],
cmdclass={
'build_ext': BuildExtension
})

View File

@ -1,49 +0,0 @@
from torch.nn.modules.module import Module
from torch.autograd import Function, Variable
import resample2d_cuda
class Resample2dFunction(Function):
@staticmethod
def forward(ctx, input1, input2, kernel_size=1, bilinear= True):
assert input1.is_contiguous()
assert input2.is_contiguous()
ctx.save_for_backward(input1, input2)
ctx.kernel_size = kernel_size
ctx.bilinear = bilinear
_, d, _, _ = input1.size()
b, _, h, w = input2.size()
output = input1.new(b, d, h, w).zero_()
resample2d_cuda.forward(input1, input2, output, kernel_size, bilinear)
return output
@staticmethod
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
assert grad_output.is_contiguous()
input1, input2 = ctx.saved_tensors
grad_input1 = Variable(input1.new(input1.size()).zero_())
grad_input2 = Variable(input1.new(input2.size()).zero_())
resample2d_cuda.backward(input1, input2, grad_output.data,
grad_input1.data, grad_input2.data,
ctx.kernel_size, ctx.bilinear)
return grad_input1, grad_input2, None, None
class Resample2d(Module):
def __init__(self, kernel_size=1, bilinear = True):
super(Resample2d, self).__init__()
self.kernel_size = kernel_size
self.bilinear = bilinear
def forward(self, input1, input2):
input1_c = input1.contiguous()
return Resample2dFunction.apply(input1_c, input2, self.kernel_size, self.bilinear)
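# Illustrative usage sketch (not part of the original file): warps `image` along a flow field.
#   resample = Resample2d()
#   warped = resample(image, flow)  # image: (B, C, H, W); flow: typically (B, 2, H, W); both contiguous CUDA tensors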

View File

@ -1,32 +0,0 @@
#include <ATen/ATen.h>
#include <torch/torch.h>
#include "resample2d_kernel.cuh"
int resample2d_cuda_forward(
at::Tensor& input1,
at::Tensor& input2,
at::Tensor& output,
int kernel_size, bool bilinear) {
resample2d_kernel_forward(input1, input2, output, kernel_size, bilinear);
return 1;
}
int resample2d_cuda_backward(
at::Tensor& input1,
at::Tensor& input2,
at::Tensor& gradOutput,
at::Tensor& gradInput1,
at::Tensor& gradInput2,
int kernel_size, bool bilinear) {
resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size, bilinear);
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)");
m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)");
}

View File

@ -1,19 +0,0 @@
#pragma once
#include <ATen/ATen.h>
void resample2d_kernel_forward(
at::Tensor& input1,
at::Tensor& input2,
at::Tensor& output,
int kernel_size,
bool bilinear);
void resample2d_kernel_backward(
at::Tensor& input1,
at::Tensor& input2,
at::Tensor& gradOutput,
at::Tensor& gradInput1,
at::Tensor& gradInput2,
int kernel_size,
bool bilinear);

View File

@ -1,29 +0,0 @@
#!/usr/bin/env python3
import os
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
cxx_args = ['-std=c++11']
nvcc_args = [
'-gencode', 'arch=compute_50,code=sm_50',
'-gencode', 'arch=compute_52,code=sm_52',
'-gencode', 'arch=compute_60,code=sm_60',
'-gencode', 'arch=compute_61,code=sm_61',
'-gencode', 'arch=compute_70,code=sm_70',
'-gencode', 'arch=compute_70,code=compute_70'
]
setup(
name='resample2d_cuda',
ext_modules=[
CUDAExtension('resample2d_cuda', [
'resample2d_cuda.cc',
'resample2d_kernel.cu'
], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
],
cmdclass={
'build_ext': BuildExtension
})

View File

@ -49,7 +49,7 @@ class GANLoss(nn.Module):
# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL
# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py
# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see scripts/compute_fdpl_perceptual_weights.py
# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs).
class FDPLLoss(nn.Module):
"""

View File

@ -1,71 +0,0 @@
# Author Masashi Kimura (Convergence Lab.)
import torch
from torch import optim
import math
class NovoGrad(optim.Optimizer):
def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(NovoGrad, self).__init__(params, defaults)
self._lr = lr
self._beta1 = betas[0]
self._beta2 = betas[1]
self._eps = eps
self._wd = weight_decay
self._grad_averaging = grad_averaging
self._momentum_initialized = False
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
if not self._momentum_initialized:
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
state = self.state[p]
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('NovoGrad does not support sparse gradients')
v = torch.norm(grad)**2
m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data
state['step'] = 0
state['v'] = v
state['m'] = m
state['grad_ema'] = None
self._momentum_initialized = True
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
state = self.state[p]
state['step'] += 1
step, v, m = state['step'], state['v'], state['m']
grad_ema = state['grad_ema']
grad = p.grad.data
g2 = torch.norm(grad)**2
grad_ema = g2 if grad_ema is None else grad_ema * \
self._beta2 + g2*(1. - self._beta2)
grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
if self._grad_averaging:
grad *= (1. - self._beta1)
g2 = torch.norm(grad)**2
v = self._beta2*v + (1. - self._beta2)*g2
m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd*p.data)
bias_correction1 = 1 - self._beta1 ** step
bias_correction2 = 1 - self._beta2 ** step
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
state['v'], state['m'] = v, m
state['grad_ema'] = grad_ema
p.data.add_(-step_size, m)
return loss
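# Illustrative usage sketch (not part of the original file); hyperparameters are examples only:
#   optimizer = NovoGrad(model.parameters(), lr=0.01, betas=(0.95, 0.98), weight_decay=1e-4)
#   loss = criterion(model(x), y)
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()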

View File

@ -1,8 +1,6 @@
import torch.nn
from models.archs.SPSR_arch import ImageGradientNoPadding
from data.weight_scheduler import get_scheduler_for_opt
from utils.util import checkpoint
import torchvision.utils as utils
from utils.weight_scheduler import get_scheduler_for_opt
#from models.steps.recursive_gen_injectors import ImageFlowInjector
from models.steps.losses import extract_params_from_state

View File

@ -6,7 +6,6 @@ import torch
from apex import amp
from collections import OrderedDict
from .injectors import create_injector
from models.novograd import NovoGrad
from utils.util import recursively_detach
logger = logging.getLogger('base')

View File

@ -1,32 +0,0 @@
name: RRDB_ESRGAN_x4
suffix: ~ # add suffix to saved images
model: sr
distortion: sr
scale: 4
crop_border: ~ # crop border during evaluation. If None (~), crop by `scale` pixels
gpu_ids: [0]
datasets:
test_1: # the 1st test dataset
name: set5
mode: LQGT
dataroot_GT: ../datasets/val_set5/Set5
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
test_2: # the 2nd test dataset
name: set14
mode: LQGT
dataroot_GT: ../datasets/val_set14/Set14
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
#### network structures
network_G:
which_model_G: RRDBNet
in_nc: 3
out_nc: 3
nf: 64
nb: 23
upscale: 4
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth

View File

@ -1,26 +0,0 @@
name: RRDB_ESRGAN_x4
suffix: ~ # add suffix to saved images
model: sr
distortion: sr
scale: 4
crop_border: ~ # crop border during evaluation. If None (~), crop by `scale` pixels
gpu_ids: [0]
datasets:
test_1: # the 1st test dataset
name: set14
mode: LQ
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
#### network structures
network_G:
which_model_G: RRDBNet
in_nc: 3
out_nc: 3
nf: 64
nb: 23
upscale: 4
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth

View File

@ -1,32 +0,0 @@
name: MSRGANx4
suffix: ~ # add suffix to saved images
model: sr
distortion: sr
scale: 4
crop_border: ~ # crop border during evaluation. If None (~), crop by `scale` pixels
gpu_ids: [0]
datasets:
test_1: # the 1st test dataset
name: set5
mode: LQGT
dataroot_GT: ../datasets/val_set5/Set5
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
test_2: # the 2nd test dataset
name: set14
mode: LQGT
dataroot_GT: ../datasets/val_set14/Set14
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
#### network structures
network_G:
which_model_G: MSRResNet
in_nc: 3
out_nc: 3
nf: 64
nb: 16
upscale: 4
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth

View File

@ -1,48 +0,0 @@
name: MSRResNetx4
suffix: ~ # add suffix to saved images
model: sr
distortion: sr
scale: 4
crop_border: ~ # crop border during evaluation. If None (~), crop by `scale` pixels
gpu_ids: [0]
datasets:
test_1: # the 1st test dataset
name: set5
mode: LQGT
dataroot_GT: ../datasets/val_set5/Set5
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
test_2: # the 2nd test dataset
name: set14
mode: LQGT
dataroot_GT: ../datasets/val_set14/Set14
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
test_3:
name: bsd100
mode: LQGT
dataroot_GT: ../datasets/BSD/BSDS100
dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4
test_4:
name: urban100
mode: LQGT
dataroot_GT: ../datasets/urban100
dataroot_LQ: ../datasets/urban100_bicLRx4
test_5:
name: div2k100
mode: LQGT
dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR
dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4
#### network structures
network_G:
which_model_G: MSRResNet
in_nc: 3
out_nc: 3
nf: 64
nb: 16
upscale: 4
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth

View File

@ -1,80 +0,0 @@
#### general settings
name: 002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new
use_tb_logger: true
model: video_base
distortion: sr
scale: 4
gpu_ids: [0,1,2,3,4,5,6,7]
#### datasets
datasets:
train:
name: REDS
mode: REDS
interval_list: [1]
random_reverse: false
border_mode: false
dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
cache_keys: ~
N_frames: 5
use_shuffle: true
n_workers: 3 # per GPU
batch_size: 32
target_size: 256
LQ_size: 64
use_flip: true
use_rot: true
color: RGB
val:
name: REDS4
mode: video_test
dataroot_GT: ../datasets/REDS4/GT
dataroot_LQ: ../datasets/REDS4/sharp_bicubic
cache_data: True
N_frames: 5
padding: new_info
#### network structures
network_G:
which_model_G: EDVR
nf: 64
nframes: 5
groups: 8
front_RBs: 5
back_RBs: 10
predeblur: false
HR_in: false
w_TSA: true
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/EDVR_REDS_SR_M_woTSA.pth
strict_load: false
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 4e-4
lr_scheme: CosineAnnealingLR_Restart
beta1: 0.9
beta2: 0.99
niter: 600000
ft_tsa_only: 50000
warmup_iter: -1 # -1: no warm up
T_period: [50000, 100000, 150000, 150000, 150000]
restarts: [50000, 150000, 300000, 450000]
restart_weights: [1, 1, 1, 1]
eta_min: !!float 1e-7
pixel_criterion: cb
pixel_weight: 1.0
val_freq: !!float 5e3
manual_seed: 0
#### logger
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3

View File

@ -1,71 +0,0 @@
#### general settings
name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S
use_tb_logger: true
model: video_base
distortion: sr
scale: 4
gpu_ids: [0,1,2,3,4,5,6,7]
#### datasets
datasets:
train:
name: REDS
mode: REDS
interval_list: [1]
random_reverse: false
border_mode: false
dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
cache_keys: ~
N_frames: 5
use_shuffle: true
n_workers: 3 # per GPU
batch_size: 32
target_size: 256
LQ_size: 64
use_flip: true
use_rot: true
color: RGB
#### network structures
network_G:
which_model_G: EDVR
nf: 64
nframes: 5
groups: 8
front_RBs: 5
back_RBs: 10
predeblur: false
HR_in: false
w_TSA: false
#### path
path:
pretrain_model_G: ~
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 4e-4
lr_scheme: CosineAnnealingLR_Restart
beta1: 0.9
beta2: 0.99
niter: 600000
warmup_iter: -1 # -1: no warm up
T_period: [150000, 150000, 150000, 150000]
restarts: [150000, 300000, 450000]
restart_weights: [1, 1, 1]
eta_min: !!float 1e-7
pixel_criterion: cb
pixel_weight: 1.0
val_freq: !!float 5e3
manual_seed: 0
#### logger
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3

View File

@ -1,82 +0,0 @@
#### general settings
name: 003_RRDB_ESRGANx4_DIV2K
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [0]
amp_opt_level: O1
#### datasets
datasets:
train:
name: DIV2K
mode: LQGT
dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4
use_shuffle: true
n_workers: 16 # per GPU
batch_size: 16
target_size: 128
use_flip: true
use_rot: true
color: RGB
val:
name: div2kval
mode: LQGT
dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic
#### network structures
network_G:
which_model_G: ResGen
nf: 256
network_D:
which_model_D: discriminator_resnet_passthrough
nf: 42
#### path
path:
pretrain_model_G: ~
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 1e-4
weight_decay_G: 0
beta1_G: 0.9
beta2_G: 0.99
lr_D: !!float 1e-4
weight_decay_D: 0
beta1_D: 0.9
beta2_D: 0.99
lr_scheme: MultiStepLR
niter: 400000
warmup_iter: -1 # no warm up
lr_steps: [50000, 100000, 200000, 300000]
lr_gamma: 0.5
mega_batch_factor: 1
pixel_criterion: l1
pixel_weight: !!float 1e-2
feature_criterion: l1
feature_weight: 1
feature_weight_decay: .98
feature_weight_decay_steps: 500
feature_weight_minimum: .1
gan_type: gan # gan | ragan
gan_weight: !!float 5e-3
D_update_ratio: 2
D_init_iters: 0
manual_seed: 10
val_freq: !!float 5e2
#### logger
logger:
print_freq: 50
save_checkpoint_freq: !!float 5e2

View File

@ -1,85 +0,0 @@
#### general settings
name: esrgan_res
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [0]
amp_opt_level: O1
#### datasets
datasets:
train:
name: DIV2K
mode: LQGT
dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4
use_shuffle: true
n_workers: 0 # per GPU
batch_size: 24
target_size: 128
use_flip: true
use_rot: true
color: RGB
val:
name: div2kval
mode: LQGT
dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic
#### network structures
network_G:
which_model_G: ResGen
nf: 256
nb_denoiser: 2
nb_upsampler: 28
network_D:
which_model_D: discriminator_resnet_passthrough
nf: 42
#### path
path:
#pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
#pretrain_model_D: ~
strict_load: true
resume_state: ../experiments/esrgan_res/training_state/15500.state
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 1e-4
weight_decay_G: 0
beta1_G: 0.9
beta2_G: 0.99
lr_D: !!float 1e-4
weight_decay_D: 0
beta1_D: 0.9
beta2_D: 0.99
lr_scheme: MultiStepLR
niter: 400000
warmup_iter: -1 # no warm up
lr_steps: [20000, 40000, 50000, 60000]
lr_gamma: 0.5
mega_batch_factor: 2
pixel_criterion: l1
pixel_weight: !!float 1e-2
feature_criterion: l1
feature_weight: 1
feature_weight_decay: 1
feature_weight_decay_steps: 500
feature_weight_minimum: 1
gan_type: gan # gan | ragan
gan_weight: !!float 1e-2
D_update_ratio: 2
D_init_iters: -1
manual_seed: 10
val_freq: !!float 5e2
#### logger
logger:
print_freq: 50
save_checkpoint_freq: !!float 5e2

View File

@ -1,85 +0,0 @@
# Not exactly the same as SRGAN in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
# With 16 Residual blocks w/o BN
#### general settings
name: 002_SRGANx4_MSRResNetx4Ini_DIV2K
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [1]
#### datasets
datasets:
train:
name: DIV2K
mode: LQGT
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
use_shuffle: true
n_workers: 6 # per GPU
batch_size: 16
target_size: 128
use_flip: true
use_rot: true
color: RGB
val:
name: val_set14
mode: LQGT
dataroot_GT: ../datasets/val_set14/Set14
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
#### network structures
network_G:
which_model_G: MSRResNet
in_nc: 3
out_nc: 3
nf: 64
nb: 16
upscale: 4
network_D:
which_model_D: discriminator_vgg_128
in_nc: 3
nf: 64
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 1e-4
weight_decay_G: 0
beta1_G: 0.9
beta2_G: 0.99
lr_D: !!float 1e-4
weight_decay_D: 0
beta1_D: 0.9
beta2_D: 0.99
lr_scheme: MultiStepLR
niter: 400000
warmup_iter: -1 # no warm up
lr_steps: [50000, 100000, 200000, 300000]
lr_gamma: 0.5
pixel_criterion: l1
pixel_weight: !!float 1e-2
feature_criterion: l1
feature_weight: 1
gan_type: gan # gan | ragan
gan_weight: !!float 5e-3
D_update_ratio: 1
D_init_iters: 0
manual_seed: 10
val_freq: !!float 5e3
#### logger
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3

View File

@ -1,70 +0,0 @@
# Not exactly the same as SRResNet in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
# With 16 Residual blocks w/o BN
#### general settings
name: 001_MSRResNetx4_scratch_DIV2K
use_tb_logger: true
model: sr
distortion: sr
scale: 4
gpu_ids: [0]
#### datasets
datasets:
train:
name: DIV2K
mode: LQGT
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
use_shuffle: true
n_workers: 6 # per GPU
batch_size: 16
target_size: 128
use_flip: true
use_rot: true
color: RGB
val:
name: val_set5
mode: LQGT
dataroot_GT: ../datasets/val_set5/Set5
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
#### network structures
network_G:
which_model_G: MSRResNet
in_nc: 3
out_nc: 3
nf: 64
nb: 16
upscale: 4
#### path
path:
pretrain_model_G: ~
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 2e-4
lr_scheme: CosineAnnealingLR_Restart
beta1: 0.9
beta2: 0.99
niter: 1000000
warmup_iter: -1 # no warm up
T_period: [250000, 250000, 250000, 250000]
restarts: [250000, 500000, 750000]
restart_weights: [1, 1, 1]
eta_min: !!float 1e-7
pixel_criterion: l1
pixel_weight: 1.0
manual_seed: 10
val_freq: !!float 5e3
#### logger
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3

View File

@ -11,7 +11,7 @@ import torchvision.transforms.functional as F
from PIL import Image
from tqdm import tqdm
import options.options as option
from utils import options as option
import utils.util as util
from data import create_dataloader
from models import create_model

View File

@ -1,10 +0,0 @@
# single GPU training (image SR)
python train.py -opt options/train/train_SRResNet.yml
python train.py -opt options/train/train_SRGAN.yml
python train.py -opt options/train/train_ESRGAN.yml
# distributed training (video SR)
# 8 GPUs
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_woTSA_M.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_M.yml --launcher pytorch

View File

@ -1,20 +0,0 @@
function [im_h] = backprojection(im_h, im_l, maxIter)
[row_l, col_l,~] = size(im_l);
[row_h, col_h,~] = size(im_h);
p = fspecial('gaussian', 5, 1);
p = p.^2;
p = p./sum(p(:));
im_l = double(im_l);
im_h = double(im_h);
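% Iterative back-projection: downscale the current HR estimate, take its residual against the
% LR input, upscale the residual, and add it back (smoothed by the squared Gaussian kernel p).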
for ii = 1:maxIter
im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
im_diff = im_l - im_l_s;
im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
end

View File

@ -1,22 +0,0 @@
clear; close all; clc;
LR_folder = './LR'; % LR
preout_folder = './results'; % pre output
save_folder = './results_20bp';
filepaths = dir(fullfile(preout_folder, '*.png'));
max_iter = 20;
if ~ exist(save_folder, 'dir')
mkdir(save_folder);
end
for idx_im = 1:length(filepaths)
fprintf([num2str(idx_im) '\n']);
im_name = filepaths(idx_im).name;
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
im_out = im2double(imread(fullfile(preout_folder, im_name)));
%tic
im_out = backprojection(im_out, im_LR, max_iter);
%toc
imwrite(im_out, fullfile(save_folder, im_name));
end

View File

@ -1,25 +0,0 @@
clear; close all; clc;
LR_folder = './LR'; % LR
preout_folder = './results'; % pre output
save_folder = './results_20if';
filepaths = dir(fullfile(preout_folder, '*.png'));
max_iter = 20;
if ~ exist(save_folder, 'dir')
mkdir(save_folder);
end
for idx_im = 1:length(filepaths)
fprintf([num2str(idx_im) '\n']);
im_name = filepaths(idx_im).name;
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
im_out = im2double(imread(fullfile(preout_folder, im_name)));
J = imresize(im_LR,4,'bicubic');
%tic
for m = 1:max_iter
im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
end
%toc
imwrite(im_out, fullfile(save_folder, im_name));
end

View File

@ -1,14 +1,10 @@
import torch
import os
from PIL import Image
import numpy as np
import options.options as option
from utils import options as option
from data import create_dataloader, create_dataset
import math
from tqdm import tqdm
from torchvision import transforms
from utils.fdpl_util import dct_2d, extract_patches_2d
import random
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from utils.colors import rgb2ycbcr

View File

@ -1,27 +0,0 @@
import os.path as osp
import sys
import torch
try:
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
import models.archs.SRResNet_arch as SRResNet_arch
except ImportError:
pass
pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth')
crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
crt_net = crt_model.state_dict()
for k, v in crt_net.items():
if k in pretrained_net and 'upconv1' not in k:
crt_net[k] = pretrained_net[k]
print('replace ... ', k)
# x4 -> x3
crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2
torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')

View File

@ -2,21 +2,14 @@ import os.path as osp
import logging
import time
import argparse
from collections import OrderedDict
import os
import options.options as option
from utils import options as option
import utils.util as util
from data.util import bgr2ycbcr
import models.archs.SwitchedResidualGenerator_arch as srg
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
from switched_conv.switched_conv import compute_attention_specificity
from data import create_dataset, create_dataloader
from models import create_model
from tqdm import tqdm
import torch
import models.networks as networks
import shutil
import torchvision

View File

@ -5,10 +5,8 @@ import math
import argparse
import random
import torch
import options.options as option
from utils import util
from utils import util, options as option
from data import create_dataloader, create_dataset
from time import time
from tqdm import tqdm
from skimage import io

View File

@ -1,9 +0,0 @@
rm gen/*
rm hr/*
rm lr/*
rm pix/*
rm ref/*
rm genlr/*
rm genmr/*
rm lr_precorrupt/*
rm ref/*

View File

@ -1,21 +1,5 @@
import os.path as osp
import logging
import time
import argparse
from collections import OrderedDict
import os
import options.options as option
import utils.util as util
from data.util import bgr2ycbcr
import models.archs.SwitchedResidualGenerator_arch as srg
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
from switched_conv.switched_conv import compute_attention_specificity
from data import create_dataset, create_dataloader
from models import create_model
from tqdm import tqdm
import torch
import models.networks as networks
class CheckpointFunction(torch.autograd.Function):
@staticmethod
@ -39,7 +23,7 @@ class CheckpointFunction(torch.autograd.Function):
input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)
return (None, None) + input_grads
from models.archs.arch_util import ConvGnSilu, UpconvBlock
from models.archs.arch_util import ConvGnSilu
import torch.nn as nn
if __name__ == "__main__":
model = nn.Sequential(ConvGnSilu(3, 64, 3, norm=False),

View File

@ -3,16 +3,14 @@ import math
import argparse
import random
import logging
import shutil
from tqdm import tqdm
import torch
from data.data_sampler import DistIterSampler
import options.options as option
from utils import util
from utils import util, options as option
from data import create_dataloader, create_dataset
from models import create_model
from models.ExtensibleTrainer import ExtensibleTrainer
from time import time
@ -159,7 +157,7 @@ def main():
assert train_loader is not None
#### create model
model = create_model(opt)
model = ExtensibleTrainer(opt)
#### resume training
if resume_state:

View File

@ -3,16 +3,14 @@ import math
import argparse
import random
import logging
import shutil
from tqdm import tqdm
import torch
from data.data_sampler import DistIterSampler
import options.options as option
from utils import util
from models.ExtensibleTrainer import ExtensibleTrainer
from utils import util, options as option
from data import create_dataloader, create_dataset
from models import create_model
from time import time
@ -159,7 +157,7 @@ def main():
assert train_loader is not None
#### create model
model = create_model(opt)
model = ExtensibleTrainer(opt)
#### resume training
if resume_state:

View File

@ -1,7 +1,7 @@
# Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive"
# models which can be trained piecemeal.
import options.options as option
from utils import options as option
from models import create_model
import torch
import os

View File

@ -1,7 +1,7 @@
import argparse
import functools
import torch
import options.options as option
from utils import options as option
from models.networks import define_G