forked from mrq/DL-Art-School
Codebase cleanup
Removed a lot of legacy stuff I have no intention of using again. The plan is to shape this repo into something more extensible (get it? hah!)
This commit is contained in:
parent e620fc05ba
commit 24792bdb4f
yapf style config (deleted)
@@ -1,4 +0,0 @@
[style]
BASED_ON_STYLE = pep8
COLUMN_LIMIT = 100
SPLIT_BEFORE_NAMED_ASSIGNS = false
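For reference, the same formatting can still be applied without the config file, since yapf accepts a dict-style string for --style:

yapf --style='{based_on_style: pep8, column_limit: 100, split_before_named_assigns: false}' -i path/to/file.py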
data/Downsample_dataset.py (deleted)
@@ -1,124 +0,0 @@
import random
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image
from io import BytesIO
import torchvision.transforms.functional as F


class DownsampleDataset(data.Dataset):
    """
    Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a
    downsampled LQ image from the HQ image and feeds that as well.
    """

    def __init__(self, opt):
        super(DownsampleDataset, self).__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        self.paths_LQ, self.paths_GT = None, None
        self.sizes_LQ, self.sizes_GT = None, None
        self.LQ_env, self.GT_env = None, None  # environments for lmdb
        self.doCrop = self.opt['doCrop']

        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])

        self.data_sz_mismatch_ok = opt['mismatched_Data_OK']
        assert self.paths_GT, 'Error: GT path is empty.'
        assert self.paths_LQ, 'LQ is required for downsampling.'
        if not self.data_sz_mismatch_ok:
            assert len(self.paths_LQ) == len(
                self.paths_GT
            ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
                len(self.paths_LQ), len(self.paths_GT))
        self.random_scale_list = [1]

    def _init_lmdb(self):
        # https://github.com/chainer/chainermn/issues/129
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
                                meminit=False)
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)

    def __getitem__(self, index):
        if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
            self._init_lmdb()
        scale = self.opt['scale']
        GT_size = self.opt['target_size'] * scale

        # get GT image
        GT_path = self.paths_GT[index % len(self.paths_GT)]
        resolution = [int(s) for s in self.sizes_GT[index].split('_')
                      ] if self.data_type == 'lmdb' else None
        img_GT = util.read_img(self.GT_env, GT_path, resolution)
        if self.opt['phase'] != 'train':  # modcrop in the validation / test phase
            img_GT = util.modcrop(img_GT, scale)
        if self.opt['color']:  # change color space if necessary
            img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

        # get LQ image
        LQ_path = self.paths_LQ[index % len(self.paths_LQ)]
        resolution = [int(s) for s in self.sizes_LQ[index].split('_')
                      ] if self.data_type == 'lmdb' else None
        img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)

        if self.opt['phase'] == 'train':
            H, W, _ = img_GT.shape
            assert H >= GT_size and W >= GT_size

            H, W, C = img_LQ.shape
            LQ_size = GT_size // scale

            if self.doCrop:
                # randomly crop
                rnd_h = random.randint(0, max(0, H - LQ_size))
                rnd_w = random.randint(0, max(0, W - LQ_size))
                img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
                img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
            else:
                img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
                img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

            # augmentation - flip, rotate
            img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
                                          self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)

        # HQ needs to go to a PIL image to perform the compression-artifact transformation.
        H, W, _ = img_GT.shape
        img_GT = (img_GT * 255).astype(np.uint8)
        img_GT = Image.fromarray(img_GT)
        if self.opt['use_compression_artifacts']:
            qf = random.randrange(15, 100)
            corruption_buffer = BytesIO()
            img_GT.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
            corruption_buffer.seek(0)
            img_GT = Image.open(corruption_buffer)
        # Generate a downsampled image from HQ for feature and PIX losses.
        img_Downsampled = F.resize(img_GT, (int(H / scale), int(W / scale)))

        img_GT = F.to_tensor(img_GT)
        img_Downsampled = F.to_tensor(img_Downsampled)
        img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()

        # This may seem really messed up, but let me explain:
        # The goal is to re-use existing code as much as possible. SRGAN_model was coded to supersample, not
        # downsample, but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the
        # HQ image and the "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator
        # which already know this; we just need to trick its logic into following these rules.
        # Do this by setting LQ (which is the input into the models) = img_GT and GT (which is the expected
        # output) = img_LQ.
        # PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
        return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}

    def __len__(self):
        return max(len(self.paths_GT), len(self.paths_LQ))
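The compression-artifact trick above (re-encode through an in-memory JPEG at a random quality factor) recurs in several of these datasets. A standalone sketch of just that step, using only PIL:

import random
from io import BytesIO
from PIL import Image

def jpeg_corrupt(img: Image.Image) -> Image.Image:
    # Re-encode at a random quality factor to bake real JPEG artifacts into the image.
    qf = random.randrange(15, 100)
    buf = BytesIO()
    img.save(buf, "JPEG", quality=qf, optimize=True)
    buf.seek(0)
    return Image.open(buf)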
data/LQGT_dataset.py (deleted)
@@ -1,239 +0,0 @@
import random
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image, ImageOps
from io import BytesIO
import torchvision.transforms.functional as F


class LQGTDataset(data.Dataset):
    """
    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc.) and GT image pairs.
    If only GT images are provided, generate LQ images on-the-fly.
    """
    def get_lq_path(self, i):
        which_lq = random.randint(0, len(self.paths_LQ) - 1)
        return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]

    def __init__(self, opt):
        super(LQGTDataset, self).__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        self.paths_LQ, self.paths_GT = None, None
        self.sizes_LQ, self.sizes_GT = None, None
        self.paths_PIX, self.sizes_PIX = None, None
        self.paths_GAN, self.sizes_GAN = None, None
        self.LQ_env, self.GT_env, self.PIX_env = None, None, None  # environments for lmdbs
        self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1

        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
        if 'dataroot_LQ' in opt.keys():
            self.paths_LQ = []
            if isinstance(opt['dataroot_LQ'], list):
                # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source
                # image and we want the model to learn them all.
                for dr_lq in opt['dataroot_LQ']:
                    lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
                    self.paths_LQ.append(lq_path)
            else:
                lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
                self.paths_LQ.append(lq_path)
        self.doCrop = opt['doCrop']
        if 'dataroot_PIX' in opt.keys():
            self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
        # dataroot_GAN is an alternative source of LR images specifically for use in computing the GAN loss, where
        # LR and HR do not need to be paired.
        if 'dataroot_GAN' in opt.keys():
            self.paths_GAN, self.sizes_GAN = util.get_image_paths(self.data_type, opt['dataroot_GAN'])
            print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,))

        assert self.paths_GT, 'Error: GT path is empty.'
        self.random_scale_list = [1]

    def _init_lmdb(self):
        # https://github.com/chainer/chainermn/issues/129
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
                                meminit=False)
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)
        if 'dataroot_PIX' in self.opt.keys():
            self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False,
                                     meminit=False)

    def motion_blur(self, image, size, angle):
        # Build a horizontal line kernel, rotate it to the given angle, normalize, and convolve.
        k = np.zeros((size, size), dtype=np.float32)
        k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
        k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
        k = k * (1.0 / np.sum(k))
        return cv2.filter2D(image, -1, k)

    def __getitem__(self, index):
        if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
            self._init_lmdb()
        GT_path, LQ_path = None, None
        scale = self.opt['scale']
        GT_size = self.opt['target_size']

        # get GT image
        GT_path = self.paths_GT[index % len(self.paths_GT)]
        resolution = [int(s) for s in self.sizes_GT[index].split('_')
                      ] if self.data_type == 'lmdb' else None
        img_GT = util.read_img(self.GT_env, GT_path, resolution)
        if self.opt['phase'] != 'train':  # modcrop in the validation / test phase
            img_GT = util.modcrop(img_GT, scale)
        if self.opt['color']:  # change color space if necessary
            img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

        # get the pix image
        if self.paths_PIX is not None:
            PIX_path = self.paths_PIX[index % len(self.paths_PIX)]
            img_PIX = util.read_img(self.PIX_env, PIX_path, resolution)
            if self.opt['color']:  # change color space if necessary
                img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0]
        else:
            img_PIX = img_GT

        # get LQ image
        if self.paths_LQ:
            LQ_path = self.get_lq_path(index)
            resolution = [int(s) for s in self.sizes_LQ[index].split('_')
                          ] if self.data_type == 'lmdb' else None
            img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
        else:  # down-sampling on-the-fly
            # randomly scale during training
            if self.opt['phase'] == 'train':
                random_scale = random.choice(self.random_scale_list)
                H_s, W_s, _ = img_GT.shape

                def _mod(n, random_scale, scale, thres):
                    rlt = int(n * random_scale)
                    rlt = (rlt // scale) * scale
                    return thres if rlt < thres else rlt

                H_s = _mod(H_s, random_scale, scale, GT_size)
                W_s = _mod(W_s, random_scale, scale, GT_size)
                img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
                if img_GT.ndim == 2:
                    img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)

            H, W, _ = img_GT.shape

            # using matlab imresize
            if scale == 1:
                img_LQ = img_GT
            else:
                img_LQ = util.imresize_np(img_GT, 1 / scale, True)
            if img_LQ.ndim == 2:
                img_LQ = np.expand_dims(img_LQ, axis=2)

        img_GAN = None
        if self.paths_GAN:
            GAN_path = self.paths_GAN[index % self.sizes_GAN]
            img_GAN = util.read_img(self.LQ_env, GAN_path)

        # Enforce force_resize constraints.
        h, w, _ = img_LQ.shape
        if h % self.force_multiple != 0 or w % self.force_multiple != 0:
            h, w = (h - h % self.force_multiple), (w - w % self.force_multiple)
            img_LQ = img_LQ[:h, :w, :]
            h *= scale
            w *= scale
            img_GT = img_GT[:h, :w, :]
            img_PIX = img_PIX[:h, :w, :]

        if self.opt['phase'] == 'train':
            H, W, _ = img_GT.shape
            assert H >= GT_size and W >= GT_size

            H, W, C = img_LQ.shape
            LQ_size = GT_size // scale

            if self.doCrop:
                # randomly crop
                rnd_h = random.randint(0, max(0, H - LQ_size))
                rnd_w = random.randint(0, max(0, W - LQ_size))
                img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                if img_GAN is not None:
                    img_GAN = img_GAN[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
                img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
                img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
            else:
                if img_LQ.shape[0] != LQ_size:
                    img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
                    if img_GAN is not None:
                        img_GAN = cv2.resize(img_GAN, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
                if img_GT.shape[0] != GT_size:
                    img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
                if img_PIX.shape[0] != GT_size:
                    img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

            if 'doResizeLoss' in self.opt.keys() and self.opt['doResizeLoss']:
                r = random.randrange(0, 10)
                if r > 5:
                    img_LQ = cv2.resize(img_LQ, (int(LQ_size / 2), int(LQ_size / 2)), interpolation=cv2.INTER_LINEAR)
                    img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)

            # augmentation - flip, rotate
            img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'],
                                                   self.opt['use_rot'])

            if self.opt['use_blurring']:
                # Pick randomly between gaussian, motion, or no blur.
                blur_det = random.randint(0, 100)
                blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
                if blur_det < 40:
                    blur_sig = int(random.randrange(0, blur_magnitude))
                    img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
                elif blur_det < 70:
                    img_LQ = self.motion_blur(img_LQ, random.randrange(1, blur_magnitude * 3), random.randint(0, 360))

        if self.opt['color']:  # change color space if necessary
            img_LQ = util.channel_convert(C, self.opt['color'],
                                          [img_LQ])[0]  # TODO during val no definition
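        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
            if img_GAN is not None:
                img_GAN = cv2.cvtColor(img_GAN, cv2.COLOR_BGR2RGB)
            img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB)

        # LQ needs to go to a PIL image to perform the compression-artifact transformation.
        img_LQ = (img_LQ * 255).astype(np.uint8)
        img_LQ = Image.fromarray(img_LQ)
        if self.opt['use_compression_artifacts'] and random.random() > .25:
            qf = random.randrange(10, 70)
            corruption_buffer = BytesIO()
            img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
            corruption_buffer.seek(0)
            img_LQ = Image.open(corruption_buffer)

        if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
            img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')

        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
        img_LQ = F.to_tensor(img_LQ)
        if img_GAN is not None:
            img_GAN = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GAN, (2, 0, 1)))).float()

        if 'lq_noise' in self.opt.keys():
            lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
            img_LQ += lq_noise

        if LQ_path is None:
            LQ_path = GT_path
        d = {'LQ': img_LQ, 'GT': img_GT, 'ref': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
        if img_GAN is not None:
            d['GAN'] = img_GAN
        return d

    def __len__(self):
        return len(self.paths_GT)

For reference, the options dict consumed above can be stubbed out directly when experimenting with the dataset outside the config system. A minimal sketch, listing only keys that __init__ and __getitem__ actually read; the paths and values are illustrative assumptions:

from torch.utils.data import DataLoader
from data.LQGT_dataset import LQGTDataset

opt = {
    'data_type': 'img', 'phase': 'train', 'scale': 4, 'target_size': 128,
    'dataroot_GT': '/path/to/hq', 'dataroot_GT_weights': None, 'dataroot_LQ': '/path/to/lq',
    'doCrop': True, 'use_flip': True, 'use_rot': True, 'color': False,
    'use_blurring': False, 'use_compression_artifacts': True,
}
loader = DataLoader(LQGTDataset(opt), batch_size=16, shuffle=True, num_workers=4)
batch = next(iter(loader))  # keys: 'LQ', 'GT', 'ref', 'LQ_path', 'GT_path'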
data/LQ_dataset.py (deleted)
@@ -1,71 +0,0 @@
import numpy as np
import lmdb
import torch
import torch.utils.data as data
import data.util as util
import torchvision.transforms.functional as F
from PIL import Image
import os.path as osp
import cv2


class LQDataset(data.Dataset):
    '''Read LQ images only in the test phase.'''

    def __init__(self, opt):
        super(LQDataset, self).__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        if 'start_at' in self.opt.keys():
            self.start_at = self.opt['start_at']
        else:
            self.start_at = 0
        self.vertical_splits = self.opt['vertical_splits']
        self.paths_LQ, self.paths_GT = None, None
        self.LQ_env = None  # environment for lmdb
        self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1

        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
        self.paths_LQ = self.paths_LQ[self.start_at:]
        assert self.paths_LQ, 'Error: LQ paths are empty.'

    def _init_lmdb(self):
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)

    def __getitem__(self, index):
        if self.data_type == 'lmdb' and self.LQ_env is None:
            self._init_lmdb()
        if self.vertical_splits > 0:
            actual_index = int(index / self.vertical_splits)
        else:
            actual_index = index

        # get LQ image
        LQ_path = self.paths_LQ[actual_index]
        img_LQ = Image.open(LQ_path)
        if self.vertical_splits > 0:
            w, h = img_LQ.size
            split_index = (index % self.vertical_splits)
            w_per_split = int(w / self.vertical_splits)
            left = w_per_split * split_index
            img_LQ = F.crop(img_LQ, 0, left, h, w_per_split)

        # Enforce force_resize constraints. Note that PIL's Image.size is (width, height), so the unpacking
        # below actually reads (w, h) into (h, w); the swapped reassignment puts them back into place before
        # resize() is called with PIL's (width, height) order.
        h, w = img_LQ.size
        if h % self.force_multiple != 0 or w % self.force_multiple != 0:
            h, w = (w - w % self.force_multiple), (h - h % self.force_multiple)
            img_LQ = img_LQ.resize((w, h))

        img_LQ = F.to_tensor(img_LQ)

        img_name = osp.splitext(osp.basename(LQ_path))[0]
        LQ_path = LQ_path.replace(img_name, img_name + "_%i" % (index % self.vertical_splits))

        return {'LQ': img_LQ, 'LQ_path': LQ_path}

    def __len__(self):
        if self.vertical_splits > 0:
            return len(self.paths_LQ) * self.vertical_splits
        else:
            return len(self.paths_LQ)
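The vertical_splits bookkeeping is easiest to see with a worked example (illustrative values only):

vertical_splits = 2
for index in (4, 5):
    actual_index = index // vertical_splits  # source image 2 for both entries
    split_index = index % vertical_splits    # 0 = left strip, 1 = right strip

# So __len__ reports len(paths_LQ) * 2, and the returned LQ_path gains a
# "_0" / "_1" suffix so the split outputs do not collide on disk.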
data/__init__.py
@@ -29,14 +29,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
 def create_dataset(dataset_opt):
     mode = dataset_opt['mode']
     # datasets for image restoration
-    if mode == 'LQ':
-        from data.LQ_dataset import LQDataset as D
-    elif mode == 'LQGT':
-        from data.LQGT_dataset import LQGTDataset as D
-    # datasets for image corruption
-    elif mode == 'downsample':
-        from data.Downsample_dataset import DownsampleDataset as D
-    elif mode == 'fullimage':
+    if mode == 'fullimage':
         from data.full_image_dataset import FullImageDataset as D
     elif mode == 'single_image_extensible':
         from data.single_image_dataset import SingleImageDataset as D
File diff suppressed because it is too large.
models/SR_model.py (deleted)
@@ -1,171 +0,0 @@
import logging
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.loss import CharbonnierLoss
from apex import amp

logger = logging.getLogger('base')


class SRModel(BaseModel):
    def __init__(self, opt):
        super(SRModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = amp.initialize(networks.define_G(opt).to(self.device), opt_level=self.amp_level)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        elif opt['gpu_ids'] is not None:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach().float().cpu()
        out_dict['rlt'] = self.fake_H.detach().float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach().float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
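test_x8 is the standard 8-way self-ensemble: each of the three involutive ops (flip along one axis, flip along the other, transpose) doubles the list, giving 2^3 = 8 augmented inputs; after inference, each output is un-transformed in reverse construction order and the 8 results are averaged. A minimal sketch of the same idea against a generic model callable (this is not the repo's API, just an illustration):

import torch

def self_ensemble_x8(model, x):
    # All three ops are involutions, so applying the same op again inverts it.
    flip_v = lambda t: torch.flip(t, dims=[3])
    flip_h = lambda t: torch.flip(t, dims=[2])
    transpose = lambda t: t.transpose(2, 3)
    variants = [x]
    for op in (flip_v, flip_h, transpose):
        variants += [op(v) for v in variants]  # 1 -> 2 -> 4 -> 8 tensors
    with torch.no_grad():
        outs = [model(v) for v in variants]
    # Undo in reverse construction order: transpose first, then h, then v.
    for i in range(8):
        if i > 3:
            outs[i] = transpose(outs[i])
        if i % 4 > 1:
            outs[i] = flip_h(outs[i])
        if i % 2 == 1:
            outs[i] = flip_v(outs[i])
    return torch.stack(outs, dim=0).mean(dim=0)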
models/__init__.py (deleted)
@@ -1,22 +0,0 @@
import logging
logger = logging.getLogger('base')


def create_model(opt):
    model = opt['model']
    # image restoration
    if model == 'sr':  # PSNR-oriented super resolution
        from .SR_model import SRModel as M
    elif model == 'srgan' or model == 'corruptgan' or model == 'spsrgan':
        from .SRGAN_model import SRGANModel as M
    elif model == 'feat':
        from .feature_model import FeatureModel as M
    elif model == 'spsr':
        from .SPSR_model import SPSRModel as M
    elif model == 'extensibletrainer':
        from .ExtensibleTrainer import ExtensibleTrainer as M
    else:
        raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
    m = M(opt)
    logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
    return m
Python autograd wrapper for channelnorm_cuda (deleted)
@@ -1,39 +0,0 @@
from torch.autograd import Function, Variable
from torch.nn.modules.module import Module
import channelnorm_cuda


class ChannelNormFunction(Function):

    @staticmethod
    def forward(ctx, input1, norm_deg=2):
        assert input1.is_contiguous()
        b, _, h, w = input1.size()
        output = input1.new(b, 1, h, w).zero_()

        channelnorm_cuda.forward(input1, output, norm_deg)
        ctx.save_for_backward(input1, output)
        ctx.norm_deg = norm_deg

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input1, output = ctx.saved_tensors

        grad_input1 = Variable(input1.new(input1.size()).zero_())

        channelnorm_cuda.backward(input1, output, grad_output.data,
                                  grad_input1.data, ctx.norm_deg)

        return grad_input1, None


class ChannelNorm(Module):

    def __init__(self, norm_deg=2):
        super(ChannelNorm, self).__init__()
        self.norm_deg = norm_deg

    def forward(self, input1):
        return ChannelNormFunction.apply(input1, self.norm_deg)
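Functionally, ChannelNorm collapses the channel dimension of an NCHW tensor into a single per-pixel norm; the forward pass above allocates a (b, 1, h, w) output. The kernel source is not part of this diff, so treating norm_deg as the p of a p-norm across channels is an assumption; under that assumption, a plain-PyTorch reference sketch is:

import torch

def channel_norm_ref(x: torch.Tensor, norm_deg: float = 2.0) -> torch.Tensor:
    # (B, C, H, W) -> (B, 1, H, W), matching the output shape allocated above.
    return torch.norm(x, p=norm_deg, dim=1, keepdim=True)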
channelnorm_cuda.cc (deleted)
@@ -1,31 +0,0 @@
#include <torch/torch.h>
#include <ATen/ATen.h>

#include "channelnorm_kernel.cuh"

int channelnorm_cuda_forward(
    at::Tensor& input1,
    at::Tensor& output,
    int norm_deg) {

    channelnorm_kernel_forward(input1, output, norm_deg);
    return 1;
}


int channelnorm_cuda_backward(
    at::Tensor& input1,
    at::Tensor& output,
    at::Tensor& gradOutput,
    at::Tensor& gradInput1,
    int norm_deg) {

    channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg);
    return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)");
  m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)");
}
channelnorm_kernel.cuh (deleted)
@@ -1,16 +0,0 @@
#pragma once

#include <ATen/ATen.h>

void channelnorm_kernel_forward(
    at::Tensor& input1,
    at::Tensor& output,
    int norm_deg);


void channelnorm_kernel_backward(
    at::Tensor& input1,
    at::Tensor& output,
    at::Tensor& gradOutput,
    at::Tensor& gradInput1,
    int norm_deg);
setup.py for the channelnorm_cuda extension (deleted)
@@ -1,28 +0,0 @@
#!/usr/bin/env python3
import os
import torch

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

cxx_args = ['-std=c++11']

nvcc_args = [
    '-gencode', 'arch=compute_52,code=sm_52',
    '-gencode', 'arch=compute_60,code=sm_60',
    '-gencode', 'arch=compute_61,code=sm_61',
    '-gencode', 'arch=compute_70,code=sm_70',
    '-gencode', 'arch=compute_70,code=compute_70'
]

setup(
    name='channelnorm_cuda',
    ext_modules=[
        CUDAExtension('channelnorm_cuda', [
            'channelnorm_cuda.cc',
            'channelnorm_kernel.cu'
        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
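Each of these *_cuda packages builds the same way: run the setup script from the directory containing the listed .cc/.cu sources, e.g. python setup.py install, after which the module is importable (import channelnorm_cuda). Building requires a CUDA toolkit compatible with the installed PyTorch; the -gencode list above pins which GPU architectures get compiled.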
Python autograd wrapper for correlation_cuda (deleted)
@@ -1,61 +0,0 @@
import torch
from torch.nn.modules.module import Module
from torch.autograd import Function
import correlation_cuda


class CorrelationFunction(Function):

    @staticmethod
    def forward(ctx, input1, input2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
        ctx.save_for_backward(input1, input2)

        ctx.pad_size = pad_size
        ctx.kernel_size = kernel_size
        ctx.max_displacement = max_displacement
        ctx.stride1 = stride1
        ctx.stride2 = stride2
        ctx.corr_multiply = corr_multiply

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()
            output = input1.new()

            correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
                                     ctx.pad_size, ctx.kernel_size, ctx.max_displacement,
                                     ctx.stride1, ctx.stride2, ctx.corr_multiply)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()

            grad_input1 = input1.new()
            grad_input2 = input2.new()

            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
                                      ctx.pad_size, ctx.kernel_size, ctx.max_displacement,
                                      ctx.stride1, ctx.stride2, ctx.corr_multiply)

        return grad_input1, grad_input2, None, None, None, None, None, None


class Correlation(Module):
    def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
        super(Correlation, self).__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply

    def forward(self, input1, input2):
        result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size,
                                           self.max_displacement, self.stride1, self.stride2,
                                           self.corr_multiply)
        return result
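Correlation computes a FlowNet-style cost volume between two feature maps: per the forward implementation in correlation_cuda.cc below, the output has ((max_displacement/stride2)*2 + 1)^2 channels, one per sampled displacement. A usage sketch with FlowNetC-style hyperparameters (the values are illustrative, not something this diff prescribes):

import torch

corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20,
                   stride1=1, stride2=2, corr_multiply=1).cuda()
feat1 = torch.randn(1, 256, 48, 64, device='cuda')
feat2 = torch.randn(1, 256, 48, 64, device='cuda')
cost = corr(feat1, feat2)  # -> (1, 441, 48, 64), since ((20 // 2) * 2 + 1) ** 2 == 441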
correlation_cuda.cc (deleted)
@@ -1,173 +0,0 @@
#include <torch/torch.h>
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <stdio.h>
#include <iostream>

#include "correlation_cuda_kernel.cuh"

int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
                             int pad_size, int kernel_size, int max_displacement,
                             int stride1, int stride2, int corr_type_multiply)
{
  int batchSize = input1.size(0);

  int nInputChannels = input1.size(1);
  int inputHeight = input1.size(2);
  int inputWidth = input1.size(3);

  int kernel_radius = (kernel_size - 1) / 2;
  int border_radius = kernel_radius + max_displacement;

  int paddedInputHeight = inputHeight + 2 * pad_size;
  int paddedInputWidth = inputWidth + 2 * pad_size;

  int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);

  int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
  int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));

  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});

  rInput1.fill_(0);
  rInput2.fill_(0);
  output.fill_(0);

  int success = correlation_forward_cuda_kernel(
      output, output.size(0), output.size(1), output.size(2), output.size(3),
      output.stride(0), output.stride(1), output.stride(2), output.stride(3),
      input1, input1.size(1), input1.size(2), input1.size(3),
      input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3),
      input2, input2.size(1),
      input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3),
      rInput1, rInput2,
      pad_size, kernel_size, max_displacement, stride1, stride2, corr_type_multiply,
      at::cuda::getCurrentCUDAStream()
      //at::globalContext().getCurrentCUDAStream()
  );

  //check for errors
  if (!success) {
    AT_ERROR("CUDA call failed");
  }

  return 1;
}

int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
                              at::Tensor& gradInput1, at::Tensor& gradInput2,
                              int pad_size, int kernel_size, int max_displacement,
                              int stride1, int stride2, int corr_type_multiply)
{
  int batchSize = input1.size(0);
  int nInputChannels = input1.size(1);
  int paddedInputHeight = input1.size(2) + 2 * pad_size;
  int paddedInputWidth = input1.size(3) + 2 * pad_size;

  int height = input1.size(2);
  int width = input1.size(3);

  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  gradInput1.resize_({batchSize, nInputChannels, height, width});
  gradInput2.resize_({batchSize, nInputChannels, height, width});

  rInput1.fill_(0);
  rInput2.fill_(0);
  gradInput1.fill_(0);
  gradInput2.fill_(0);

  int success = correlation_backward_cuda_kernel(
      gradOutput, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3),
      gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3),
      input1, input1.size(1), input1.size(2), input1.size(3),
      input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3),
      input2,
      input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3),
      gradInput1,
      gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3),
      gradInput2, gradInput2.size(1),
      gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3),
      rInput1, rInput2,
      pad_size, kernel_size, max_displacement, stride1, stride2, corr_type_multiply,
      at::cuda::getCurrentCUDAStream()
      //at::globalContext().getCurrentCUDAStream()
  );

  if (!success) {
    AT_ERROR("CUDA call failed");
  }

  return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
  m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
}
correlation_cuda_kernel.cuh (deleted)
@@ -1,91 +0,0 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <cuda_runtime.h>

int correlation_forward_cuda_kernel(at::Tensor& output,
                                    int ob, int oc, int oh, int ow,
                                    int osb, int osc, int osh, int osw,

                                    at::Tensor& input1,
                                    int ic, int ih, int iw,
                                    int isb, int isc, int ish, int isw,

                                    at::Tensor& input2,
                                    int gc,
                                    int gsb, int gsc, int gsh, int gsw,

                                    at::Tensor& rInput1,
                                    at::Tensor& rInput2,
                                    int pad_size,
                                    int kernel_size,
                                    int max_displacement,
                                    int stride1,
                                    int stride2,
                                    int corr_type_multiply,
                                    cudaStream_t stream);


int correlation_backward_cuda_kernel(
    at::Tensor& gradOutput,
    int gob, int goc, int goh, int gow,
    int gosb, int gosc, int gosh, int gosw,

    at::Tensor& input1,
    int ic, int ih, int iw,
    int isb, int isc, int ish, int isw,

    at::Tensor& input2,
    int gsb, int gsc, int gsh, int gsw,

    at::Tensor& gradInput1,
    int gisb, int gisc, int gish, int gisw,

    at::Tensor& gradInput2,
    int ggc,
    int ggsb, int ggsc, int ggsh, int ggsw,

    at::Tensor& rInput1,
    at::Tensor& rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2,
    int corr_type_multiply,
    cudaStream_t stream);
setup.py for the correlation_cuda extension (deleted)
@@ -1,29 +0,0 @@
#!/usr/bin/env python3
import os
import torch

from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

cxx_args = ['-std=c++11']

nvcc_args = [
    '-gencode', 'arch=compute_50,code=sm_50',
    '-gencode', 'arch=compute_52,code=sm_52',
    '-gencode', 'arch=compute_60,code=sm_60',
    '-gencode', 'arch=compute_61,code=sm_61',
    '-gencode', 'arch=compute_70,code=sm_70',
    '-gencode', 'arch=compute_70,code=compute_70'
]

setup(
    name='correlation_cuda',
    ext_modules=[
        CUDAExtension('correlation_cuda', [
            'correlation_cuda.cc',
            'correlation_cuda_kernel.cu'
        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
Python autograd wrapper for resample2d_cuda (deleted)
@@ -1,49 +0,0 @@
from torch.nn.modules.module import Module
from torch.autograd import Function, Variable
import resample2d_cuda


class Resample2dFunction(Function):

    @staticmethod
    def forward(ctx, input1, input2, kernel_size=1, bilinear=True):
        assert input1.is_contiguous()
        assert input2.is_contiguous()

        ctx.save_for_backward(input1, input2)
        ctx.kernel_size = kernel_size
        ctx.bilinear = bilinear

        _, d, _, _ = input1.size()
        b, _, h, w = input2.size()
        output = input1.new(b, d, h, w).zero_()

        resample2d_cuda.forward(input1, input2, output, kernel_size, bilinear)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        assert grad_output.is_contiguous()

        input1, input2 = ctx.saved_tensors

        grad_input1 = Variable(input1.new(input1.size()).zero_())
        grad_input2 = Variable(input1.new(input2.size()).zero_())

        resample2d_cuda.backward(input1, input2, grad_output.data,
                                 grad_input1.data, grad_input2.data,
                                 ctx.kernel_size, ctx.bilinear)

        return grad_input1, grad_input2, None, None


class Resample2d(Module):

    def __init__(self, kernel_size=1, bilinear=True):
        super(Resample2d, self).__init__()
        self.kernel_size = kernel_size
        self.bilinear = bilinear

    def forward(self, input1, input2):
        input1_c = input1.contiguous()
        return Resample2dFunction.apply(input1_c, input2, self.kernel_size, self.bilinear)
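In FlowNet-style pipelines Resample2d is the warping op: input1 is the image (or feature map) to sample from, input2 is a 2-channel flow field of per-pixel offsets, and the output takes its batch and spatial size from the flow, as the shape logic in forward() above shows. A sketch, assuming the flownet2-pytorch convention that channel 0 of the flow holds the x offset:

import torch

warp = Resample2d().cuda()
img = torch.randn(1, 3, 256, 256, device='cuda')   # source to sample from
flow = torch.zeros(1, 2, 256, 256, device='cuda')  # (dx, dy) per pixel; zero flow = identity warp
warped = warp(img, flow)                           # -> (1, 3, 256, 256)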
resample2d_cuda.cc (deleted)
@@ -1,32 +0,0 @@
#include <ATen/ATen.h>
#include <torch/torch.h>

#include "resample2d_kernel.cuh"

int resample2d_cuda_forward(
    at::Tensor& input1,
    at::Tensor& input2,
    at::Tensor& output,
    int kernel_size, bool bilinear) {
  resample2d_kernel_forward(input1, input2, output, kernel_size, bilinear);
  return 1;
}

int resample2d_cuda_backward(
    at::Tensor& input1,
    at::Tensor& input2,
    at::Tensor& gradOutput,
    at::Tensor& gradInput1,
    at::Tensor& gradInput2,
    int kernel_size, bool bilinear) {
  resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size, bilinear);
  return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)");
  m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)");
}
resample2d_kernel.cuh (deleted)
@@ -1,19 +0,0 @@
#pragma once

#include <ATen/ATen.h>

void resample2d_kernel_forward(
    at::Tensor& input1,
    at::Tensor& input2,
    at::Tensor& output,
    int kernel_size,
    bool bilinear);

void resample2d_kernel_backward(
    at::Tensor& input1,
    at::Tensor& input2,
    at::Tensor& gradOutput,
    at::Tensor& gradInput1,
    at::Tensor& gradInput2,
    int kernel_size,
    bool bilinear);
setup.py for the resample2d_cuda extension (deleted)
@@ -1,29 +0,0 @@
#!/usr/bin/env python3
import os
import torch

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

cxx_args = ['-std=c++11']

nvcc_args = [
    '-gencode', 'arch=compute_50,code=sm_50',
    '-gencode', 'arch=compute_52,code=sm_52',
    '-gencode', 'arch=compute_60,code=sm_60',
    '-gencode', 'arch=compute_61,code=sm_61',
    '-gencode', 'arch=compute_70,code=sm_70',
    '-gencode', 'arch=compute_70,code=compute_70'
]

setup(
    name='resample2d_cuda',
    ext_modules=[
        CUDAExtension('resample2d_cuda', [
            'resample2d_cuda.cc',
            'resample2d_kernel.cu'
        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
models/loss.py
@@ -49,7 +49,7 @@ class GANLoss(nn.Module):
 
 
 # Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL
-# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py
+# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see scripts/compute_fdpl_perceptual_weights.py
 # In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs).
 class FDPLLoss(nn.Module):
     """
models/novograd.py (deleted)
@@ -1,71 +0,0 @@
# Author Masashi Kimura (Convergence Lab.)
import torch
from torch import optim
import math


class NovoGrad(optim.Optimizer):
    def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(NovoGrad, self).__init__(params, defaults)
        self._lr = lr
        self._beta1 = betas[0]
        self._beta2 = betas[1]
        self._eps = eps
        self._wd = weight_decay
        self._grad_averaging = grad_averaging

        self._momentum_initialized = False

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        if not self._momentum_initialized:
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    grad = p.grad.data
                    if grad.is_sparse:
                        raise RuntimeError('NovoGrad does not support sparse gradients')

                    v = torch.norm(grad)**2
                    m = grad / (torch.sqrt(v) + self._eps) + self._wd * p.data
                    state['step'] = 0
                    state['v'] = v
                    state['m'] = m
                    state['grad_ema'] = None
            self._momentum_initialized = True

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                state['step'] += 1

                step, v, m = state['step'], state['v'], state['m']
                grad_ema = state['grad_ema']

                grad = p.grad.data
                g2 = torch.norm(grad)**2
                grad_ema = g2 if grad_ema is None else grad_ema * \
                    self._beta2 + g2 * (1. - self._beta2)
                grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)

                if self._grad_averaging:
                    grad *= (1. - self._beta1)

                g2 = torch.norm(grad)**2
                v = self._beta2 * v + (1. - self._beta2) * g2
                m = self._beta1 * m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data)
                bias_correction1 = 1 - self._beta1 ** step
                bias_correction2 = 1 - self._beta2 ** step
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                state['v'], state['m'] = v, m
                state['grad_ema'] = grad_ema
                p.data.add_(m, alpha=-step_size)  # p <- p - step_size * m
        return loss
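NovoGrad drops in wherever torch.optim.Adam would. A minimal usage sketch; the toy model and data here are placeholders:

import torch
import torch.nn as nn
from models.novograd import NovoGrad

model = nn.Linear(64, 1)
opt = NovoGrad(model.parameters(), lr=0.01, betas=(0.95, 0.98), weight_decay=1e-4)

x, y = torch.randn(8, 64), torch.randn(8, 1)
opt.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()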
models/ExtensibleTrainer.py
@@ -1,8 +1,6 @@
 import torch.nn
 from models.archs.SPSR_arch import ImageGradientNoPadding
-from data.weight_scheduler import get_scheduler_for_opt
+from utils.weight_scheduler import get_scheduler_for_opt
 from utils.util import checkpoint
-import torchvision.utils as utils
 #from models.steps.recursive_gen_injectors import ImageFlowInjector
 from models.steps.losses import extract_params_from_state
-
@@ -6,7 +6,6 @@ import torch
 from apex import amp
 from collections import OrderedDict
 from .injectors import create_injector
-from models.novograd import NovoGrad
 from utils.util import recursively_detach
 
 logger = logging.getLogger('base')
ESRGAN test config (deleted)
@@ -1,32 +0,0 @@
name: RRDB_ESRGAN_x4
suffix: ~  # add suffix to saved images
model: sr
distortion: sr
scale: 4
crop_border: ~  # crop border when evaluation. If None(~), crop the scale pixels
gpu_ids: [0]

datasets:
  test_1:  # the 1st test dataset
    name: set5
    mode: LQGT
    dataroot_GT: ../datasets/val_set5/Set5
    dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
  test_2:  # the 2nd test dataset
    name: set14
    mode: LQGT
    dataroot_GT: ../datasets/val_set14/Set14
    dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4

#### network structures
network_G:
  which_model_G: RRDBNet
  in_nc: 3
  out_nc: 3
  nf: 64
  nb: 23
  upscale: 4

#### path
path:
  pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth
ESRGAN LQ-only test config (deleted)
@@ -1,26 +0,0 @@
name: RRDB_ESRGAN_x4
suffix: ~  # add suffix to saved images
model: sr
distortion: sr
scale: 4
crop_border: ~  # crop border when evaluation. If None(~), crop the scale pixels
gpu_ids: [0]

datasets:
  test_1:  # the 1st test dataset
    name: set14
    mode: LQ
    dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4

#### network structures
network_G:
  which_model_G: RRDBNet
  in_nc: 3
  out_nc: 3
  nf: 64
  nb: 23
  upscale: 4

#### path
path:
  pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth
MSRGAN test config (deleted)
@@ -1,32 +0,0 @@
name: MSRGANx4
suffix: ~  # add suffix to saved images
model: sr
distortion: sr
scale: 4
crop_border: ~  # crop border when evaluation. If None(~), crop the scale pixels
gpu_ids: [0]

datasets:
  test_1:  # the 1st test dataset
    name: set5
    mode: LQGT
    dataroot_GT: ../datasets/val_set5/Set5
    dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
  test_2:  # the 2nd test dataset
    name: set14
    mode: LQGT
    dataroot_GT: ../datasets/val_set14/Set14
    dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4

#### network structures
network_G:
  which_model_G: MSRResNet
  in_nc: 3
  out_nc: 3
  nf: 64
  nb: 16
  upscale: 4

#### path
path:
  pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth
@@ -1,48 +0,0 @@
-name: MSRResNetx4
-suffix: ~  # add suffix to saved images
-model: sr
-distortion: sr
-scale: 4
-crop_border: ~  # crop border when evaluating. If None(~), crop the scale pixels
-gpu_ids: [0]
-
-datasets:
-  test_1:  # the 1st test dataset
-    name: set5
-    mode: LQGT
-    dataroot_GT: ../datasets/val_set5/Set5
-    dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
-  test_2:  # the 2nd test dataset
-    name: set14
-    mode: LQGT
-    dataroot_GT: ../datasets/val_set14/Set14
-    dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
-  test_3:
-    name: bsd100
-    mode: LQGT
-    dataroot_GT: ../datasets/BSD/BSDS100
-    dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4
-  test_4:
-    name: urban100
-    mode: LQGT
-    dataroot_GT: ../datasets/urban100
-    dataroot_LQ: ../datasets/urban100_bicLRx4
-  test_5:
-    name: div2k100
-    mode: LQGT
-    dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR
-    dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4
-
-
-#### network structures
-network_G:
-  which_model_G: MSRResNet
-  in_nc: 3
-  out_nc: 3
-  nf: 64
-  nb: 16
-  upscale: 4
-
-#### path
-path:
-  pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
@@ -1,80 +0,0 @@
-#### general settings
-name: 002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new
-use_tb_logger: true
-model: video_base
-distortion: sr
-scale: 4
-gpu_ids: [0,1,2,3,4,5,6,7]
-
-#### datasets
-datasets:
-  train:
-    name: REDS
-    mode: REDS
-    interval_list: [1]
-    random_reverse: false
-    border_mode: false
-    dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
-    dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
-    cache_keys: ~
-
-    N_frames: 5
-    use_shuffle: true
-    n_workers: 3  # per GPU
-    batch_size: 32
-    target_size: 256
-    LQ_size: 64
-    use_flip: true
-    use_rot: true
-    color: RGB
-  val:
-    name: REDS4
-    mode: video_test
-    dataroot_GT: ../datasets/REDS4/GT
-    dataroot_LQ: ../datasets/REDS4/sharp_bicubic
-    cache_data: True
-    N_frames: 5
-    padding: new_info
-
-#### network structures
-network_G:
-  which_model_G: EDVR
-  nf: 64
-  nframes: 5
-  groups: 8
-  front_RBs: 5
-  back_RBs: 10
-  predeblur: false
-  HR_in: false
-  w_TSA: true
-
-#### path
-path:
-  pretrain_model_G: ../experiments/pretrained_models/EDVR_REDS_SR_M_woTSA.pth
-  strict_load: false
-  resume_state: ~
-
-#### training settings: learning rate scheme, loss
-train:
-  lr_G: !!float 4e-4
-  lr_scheme: CosineAnnealingLR_Restart
-  beta1: 0.9
-  beta2: 0.99
-  niter: 600000
-  ft_tsa_only: 50000
-  warmup_iter: -1  # -1: no warm up
-  T_period: [50000, 100000, 150000, 150000, 150000]
-  restarts: [50000, 150000, 300000, 450000]
-  restart_weights: [1, 1, 1, 1]
-  eta_min: !!float 1e-7
-
-  pixel_criterion: cb
-  pixel_weight: 1.0
-  val_freq: !!float 5e3
-
-  manual_seed: 0
-
-#### logger
-logger:
-  print_freq: 100
-  save_checkpoint_freq: !!float 5e3
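The CosineAnnealingLR_Restart scheme in the config above anneals the LR over each T_period entry from lr_G down to eta_min, then restarts at the listed iteration counts, scaled by the matching restart_weights entry. A rough sketch of the schedule math (not the project's scheduler class; names are illustrative):

import math

def cosine_restart_lr(step, base_lr=4e-4, eta_min=1e-7,
                      periods=(50000, 100000, 150000, 150000, 150000),
                      restart_weights=(1, 1, 1, 1)):
    # Walk the periods to find which annealing window `step` falls in.
    start, weight = 0, 1.0
    for i, t in enumerate(periods):
        if step < start + t:
            frac = (step - start) / t  # position within this cosine window
            return eta_min + (base_lr * weight - eta_min) * 0.5 * (1 + math.cos(math.pi * frac))
        start += t
        # After a restart, the peak LR is rescaled by the restart weight.
        weight = restart_weights[i] if i < len(restart_weights) else weight
    return eta_min

At step 0 this returns base_lr; at the end of each period it bottoms out at eta_min before the next restart.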
@@ -1,71 +0,0 @@
-#### general settings
-name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S
-use_tb_logger: true
-model: video_base
-distortion: sr
-scale: 4
-gpu_ids: [0,1,2,3,4,5,6,7]
-
-#### datasets
-datasets:
-  train:
-    name: REDS
-    mode: REDS
-    interval_list: [1]
-    random_reverse: false
-    border_mode: false
-    dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
-    dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
-    cache_keys: ~
-
-    N_frames: 5
-    use_shuffle: true
-    n_workers: 3  # per GPU
-    batch_size: 32
-    target_size: 256
-    LQ_size: 64
-    use_flip: true
-    use_rot: true
-    color: RGB
-
-#### network structures
-network_G:
-  which_model_G: EDVR
-  nf: 64
-  nframes: 5
-  groups: 8
-  front_RBs: 5
-  back_RBs: 10
-  predeblur: false
-  HR_in: false
-  w_TSA: false
-
-#### path
-path:
-  pretrain_model_G: ~
-  strict_load: true
-  resume_state: ~
-
-#### training settings: learning rate scheme, loss
-train:
-  lr_G: !!float 4e-4
-  lr_scheme: CosineAnnealingLR_Restart
-  beta1: 0.9
-  beta2: 0.99
-  niter: 600000
-  warmup_iter: -1  # -1: no warm up
-  T_period: [150000, 150000, 150000, 150000]
-  restarts: [150000, 300000, 450000]
-  restart_weights: [1, 1, 1]
-  eta_min: !!float 1e-7
-
-  pixel_criterion: cb
-  pixel_weight: 1.0
-  val_freq: !!float 5e3
-
-  manual_seed: 0
-
-#### logger
-logger:
-  print_freq: 100
-  save_checkpoint_freq: !!float 5e3
@@ -1,82 +0,0 @@
-#### general settings
-name: 003_RRDB_ESRGANx4_DIV2K
-use_tb_logger: true
-model: srgan
-distortion: sr
-scale: 4
-gpu_ids: [0]
-amp_opt_level: O1
-
-#### datasets
-datasets:
-  train:
-    name: DIV2K
-    mode: LQGT
-    dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
-    dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4
-
-    use_shuffle: true
-    n_workers: 16  # per GPU
-    batch_size: 16
-    target_size: 128
-    use_flip: true
-    use_rot: true
-    color: RGB
-  val:
-    name: div2kval
-    mode: LQGT
-    dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
-    dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic
-
-#### network structures
-network_G:
-  which_model_G: ResGen
-  nf: 256
-network_D:
-  which_model_D: discriminator_resnet_passthrough
-  nf: 42
-
-#### path
-path:
-  pretrain_model_G: ~
-  strict_load: true
-  resume_state: ~
-
-#### training settings: learning rate scheme, loss
-train:
-  lr_G: !!float 1e-4
-  weight_decay_G: 0
-  beta1_G: 0.9
-  beta2_G: 0.99
-  lr_D: !!float 1e-4
-  weight_decay_D: 0
-  beta1_D: 0.9
-  beta2_D: 0.99
-  lr_scheme: MultiStepLR
-
-  niter: 400000
-  warmup_iter: -1  # no warm up
-  lr_steps: [50000, 100000, 200000, 300000]
-  lr_gamma: 0.5
-  mega_batch_factor: 1
-
-  pixel_criterion: l1
-  pixel_weight: !!float 1e-2
-  feature_criterion: l1
-  feature_weight: 1
-  feature_weight_decay: .98
-  feature_weight_decay_steps: 500
-  feature_weight_minimum: .1
-  gan_type: gan  # gan | ragan
-  gan_weight: !!float 5e-3
-
-  D_update_ratio: 2
-  D_init_iters: 0
-
-  manual_seed: 10
-  val_freq: !!float 5e2
-
-#### logger
-logger:
-  print_freq: 50
-  save_checkpoint_freq: !!float 5e2
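The feature_weight_decay fields in the config above describe a stepped exponential decay of the perceptual-loss weight down to a floor. Roughly, as a hypothetical helper mirroring those fields:

def feature_weight_at(step, w0=1.0, decay=0.98, decay_steps=500, minimum=0.1):
    # Multiply by `decay` once every `decay_steps` iterations, clamped at `minimum`.
    return max(minimum, w0 * decay ** (step // decay_steps))

With the values above, the feature loss weight drifts from 1.0 toward 0.1 over training, shifting emphasis to the pixel and GAN terms.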
@@ -1,85 +0,0 @@
-#### general settings
-name: esrgan_res
-use_tb_logger: true
-model: srgan
-distortion: sr
-scale: 4
-gpu_ids: [0]
-amp_opt_level: O1
-
-#### datasets
-datasets:
-  train:
-    name: DIV2K
-    mode: LQGT
-    dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
-    dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4
-
-    use_shuffle: true
-    n_workers: 0  # per GPU
-    batch_size: 24
-    target_size: 128
-    use_flip: true
-    use_rot: true
-    color: RGB
-  val:
-    name: div2kval
-    mode: LQGT
-    dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
-    dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic
-
-#### network structures
-network_G:
-  which_model_G: ResGen
-  nf: 256
-  nb_denoiser: 2
-  nb_upsampler: 28
-network_D:
-  which_model_D: discriminator_resnet_passthrough
-  nf: 42
-
-#### path
-path:
-  #pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
-  #pretrain_model_D: ~
-  strict_load: true
-  resume_state: ../experiments/esrgan_res/training_state/15500.state
-
-#### training settings: learning rate scheme, loss
-train:
-  lr_G: !!float 1e-4
-  weight_decay_G: 0
-  beta1_G: 0.9
-  beta2_G: 0.99
-  lr_D: !!float 1e-4
-  weight_decay_D: 0
-  beta1_D: 0.9
-  beta2_D: 0.99
-  lr_scheme: MultiStepLR
-
-  niter: 400000
-  warmup_iter: -1  # no warm up
-  lr_steps: [20000, 40000, 50000, 60000]
-  lr_gamma: 0.5
-  mega_batch_factor: 2
-
-  pixel_criterion: l1
-  pixel_weight: !!float 1e-2
-  feature_criterion: l1
-  feature_weight: 1
-  feature_weight_decay: 1
-  feature_weight_decay_steps: 500
-  feature_weight_minimum: 1
-  gan_type: gan  # gan | ragan
-  gan_weight: !!float 1e-2
-
-  D_update_ratio: 2
-  D_init_iters: -1
-
-  manual_seed: 10
-  val_freq: !!float 5e2
-
-#### logger
-logger:
-  print_freq: 50
-  save_checkpoint_freq: !!float 5e2
@@ -1,85 +0,0 @@
-# Not exactly the same as SRGAN in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
-# With 16 Residual blocks w/o BN
-
-#### general settings
-name: 002_SRGANx4_MSRResNetx4Ini_DIV2K
-use_tb_logger: true
-model: srgan
-distortion: sr
-scale: 4
-gpu_ids: [1]
-
-#### datasets
-datasets:
-  train:
-    name: DIV2K
-    mode: LQGT
-    dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
-    dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
-
-    use_shuffle: true
-    n_workers: 6  # per GPU
-    batch_size: 16
-    target_size: 128
-    use_flip: true
-    use_rot: true
-    color: RGB
-  val:
-    name: val_set14
-    mode: LQGT
-    dataroot_GT: ../datasets/val_set14/Set14
-    dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
-
-#### network structures
-network_G:
-  which_model_G: MSRResNet
-  in_nc: 3
-  out_nc: 3
-  nf: 64
-  nb: 16
-  upscale: 4
-network_D:
-  which_model_D: discriminator_vgg_128
-  in_nc: 3
-  nf: 64
-
-#### path
-path:
-  pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
-  strict_load: true
-  resume_state: ~
-
-#### training settings: learning rate scheme, loss
-train:
-  lr_G: !!float 1e-4
-  weight_decay_G: 0
-  beta1_G: 0.9
-  beta2_G: 0.99
-  lr_D: !!float 1e-4
-  weight_decay_D: 0
-  beta1_D: 0.9
-  beta2_D: 0.99
-  lr_scheme: MultiStepLR
-
-  niter: 400000
-  warmup_iter: -1  # no warm up
-  lr_steps: [50000, 100000, 200000, 300000]
-  lr_gamma: 0.5
-
-  pixel_criterion: l1
-  pixel_weight: !!float 1e-2
-  feature_criterion: l1
-  feature_weight: 1
-  gan_type: gan  # gan | ragan
-  gan_weight: !!float 5e-3
-
-  D_update_ratio: 1
-  D_init_iters: 0
-
-  manual_seed: 10
-  val_freq: !!float 5e3
-
-#### logger
-logger:
-  print_freq: 100
-  save_checkpoint_freq: !!float 5e3
@@ -1,70 +0,0 @@
-# Not exactly the same as SRResNet in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
-# With 16 Residual blocks w/o BN
-
-#### general settings
-name: 001_MSRResNetx4_scratch_DIV2K
-use_tb_logger: true
-model: sr
-distortion: sr
-scale: 4
-gpu_ids: [0]
-
-#### datasets
-datasets:
-  train:
-    name: DIV2K
-    mode: LQGT
-    dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
-    dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
-
-    use_shuffle: true
-    n_workers: 6  # per GPU
-    batch_size: 16
-    target_size: 128
-    use_flip: true
-    use_rot: true
-    color: RGB
-  val:
-    name: val_set5
-    mode: LQGT
-    dataroot_GT: ../datasets/val_set5/Set5
-    dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
-
-#### network structures
-network_G:
-  which_model_G: MSRResNet
-  in_nc: 3
-  out_nc: 3
-  nf: 64
-  nb: 16
-  upscale: 4
-
-#### path
-path:
-  pretrain_model_G: ~
-  strict_load: true
-  resume_state: ~
-
-#### training settings: learning rate scheme, loss
-train:
-  lr_G: !!float 2e-4
-  lr_scheme: CosineAnnealingLR_Restart
-  beta1: 0.9
-  beta2: 0.99
-  niter: 1000000
-  warmup_iter: -1  # no warm up
-  T_period: [250000, 250000, 250000, 250000]
-  restarts: [250000, 500000, 750000]
-  restart_weights: [1, 1, 1]
-  eta_min: !!float 1e-7
-
-  pixel_criterion: l1
-  pixel_weight: 1.0
-
-  manual_seed: 10
-  val_freq: !!float 5e3
-
-#### logger
-logger:
-  print_freq: 100
-  save_checkpoint_freq: !!float 5e3
@@ -11,7 +11,7 @@ import torchvision.transforms.functional as F
 from PIL import Image
 from tqdm import tqdm

-import options.options as option
+from utils import options as option
 import utils.util as util
 from data import create_dataloader
 from models import create_model
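This import change recurs throughout the commit: the options module moves from options.options to utils.options, and call sites keep the same shape. A sketch of the usage after the move, assuming option.parse keeps its (path, is_train) signature:

from utils import options as option

# Same call pattern as before the module move (signature assumed unchanged).
opt = option.parse('options/train/some_config.yml', is_train=True)  # hypothetical path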
@@ -1,10 +0,0 @@
-# single GPU training (image SR)
-python train.py -opt options/train/train_SRResNet.yml
-python train.py -opt options/train/train_SRGAN.yml
-python train.py -opt options/train/train_ESRGAN.yml
-
-
-# distributed training (video SR)
-# 8 GPUs
-python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_woTSA_M.yml --launcher pytorch
-python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_M.yml --launcher pytorch
@@ -1,20 +0,0 @@
-function [im_h] = backprojection(im_h, im_l, maxIter)
-
-[row_l, col_l, ~] = size(im_l);
-[row_h, col_h, ~] = size(im_h);
-
-p = fspecial('gaussian', 5, 1);
-p = p.^2;
-p = p./sum(p(:));
-
-im_l = double(im_l);
-im_h = double(im_h);
-
-for ii = 1:maxIter
-    im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
-    im_diff = im_l - im_l_s;
-    im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
-    im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
-    im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
-    im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
-end
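For readers without MATLAB, the deleted back-projection routine translates to NumPy almost line for line. A sketch, with cv2.resize standing in for imresize's bicubic mode and a normalized squared Gaussian kernel for fspecial:

import cv2
import numpy as np
from scipy.ndimage import convolve

def gaussian_kernel(size=5, sigma=1.0):
    ax = np.arange(size) - (size - 1) / 2.0
    g = np.exp(-(ax ** 2) / (2 * sigma ** 2))
    k = np.outer(g, g)
    k = k ** 2                  # p = p.^2
    return k / k.sum()          # p = p ./ sum(p(:))

def backprojection(im_h, im_l, max_iter=20):
    p = gaussian_kernel()
    im_h = im_h.astype(np.float32)
    im_l = im_l.astype(np.float32)
    h, w = im_h.shape[:2]
    lh, lw = im_l.shape[:2]
    for _ in range(max_iter):
        # Downsample the current HR estimate and take the LR-space residual.
        im_l_s = cv2.resize(im_h, (lw, lh), interpolation=cv2.INTER_CUBIC)
        diff = cv2.resize(im_l - im_l_s, (w, h), interpolation=cv2.INTER_CUBIC)
        # Diffuse the residual back into each channel (zero padding,
        # matching MATLAB's conv2 'same').
        for c in range(im_h.shape[2]):
            im_h[:, :, c] += convolve(diff[:, :, c], p, mode='constant', cval=0.0)
    return im_h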
@@ -1,22 +0,0 @@
-clear; close all; clc;
-
-LR_folder = './LR'; % LR
-preout_folder = './results'; % pre output
-save_folder = './results_20bp';
-filepaths = dir(fullfile(preout_folder, '*.png'));
-max_iter = 20;
-
-if ~ exist(save_folder, 'dir')
-    mkdir(save_folder);
-end
-
-for idx_im = 1:length(filepaths)
-    fprintf([num2str(idx_im) '\n']);
-    im_name = filepaths(idx_im).name;
-    im_LR = im2double(imread(fullfile(LR_folder, im_name)));
-    im_out = im2double(imread(fullfile(preout_folder, im_name)));
-    %tic
-    im_out = backprojection(im_out, im_LR, max_iter);
-    %toc
-    imwrite(im_out, fullfile(save_folder, im_name));
-end
@@ -1,25 +0,0 @@
-clear; close all; clc;
-
-LR_folder = './LR'; % LR
-preout_folder = './results'; % pre output
-save_folder = './results_20if';
-filepaths = dir(fullfile(preout_folder, '*.png'));
-max_iter = 20;
-
-if ~ exist(save_folder, 'dir')
-    mkdir(save_folder);
-end
-
-for idx_im = 1:length(filepaths)
-    fprintf([num2str(idx_im) '\n']);
-    im_name = filepaths(idx_im).name;
-    im_LR = im2double(imread(fullfile(LR_folder, im_name)));
-    im_out = im2double(imread(fullfile(preout_folder, im_name)));
-    J = imresize(im_LR,4,'bicubic');
-    %tic
-    for m = 1:max_iter
-        im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
-    end
-    %toc
-    imwrite(im_out, fullfile(save_folder, im_name));
-end
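The second script is simpler than the kernel-based variant: it repeatedly adds back the bicubic downsample/upsample residual against the bicubic-upscaled LR image. In Python terms, a sketch using the same cv2 stand-in for imresize:

import cv2

def iterative_refine(im_out, im_LR, max_iter=20, scale=4):
    # Bicubic-upscaled LR reference, like J = imresize(im_LR, 4, 'bicubic').
    J = cv2.resize(im_LR, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
    for _ in range(max_iter):
        down = cv2.resize(im_out, None, fx=1.0 / scale, fy=1.0 / scale,
                          interpolation=cv2.INTER_CUBIC)
        up = cv2.resize(down, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
        im_out = im_out + (J - up)  # add back the down/up residual
    return im_out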
@@ -1,14 +1,10 @@
 import torch
-import os
-from PIL import Image
 import numpy as np
-import options.options as option
+from utils import options as option
 from data import create_dataloader, create_dataset
 import math
 from tqdm import tqdm
-from torchvision import transforms
 from utils.fdpl_util import dct_2d, extract_patches_2d
-import random
 import matplotlib.pyplot as plt
 from mpl_toolkits.axes_grid1 import make_axes_locatable
 from utils.colors import rgb2ycbcr
@@ -1,27 +0,0 @@
-import os.path as osp
-import sys
-import torch
-try:
-    sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
-    import models.archs.SRResNet_arch as SRResNet_arch
-except ImportError:
-    pass
-
-pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth')
-crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
-crt_net = crt_model.state_dict()
-
-for k, v in crt_net.items():
-    if k in pretrained_net and 'upconv1' not in k:
-        crt_net[k] = pretrained_net[k]
-        print('replace ... ', k)
-
-# x4 -> x3
-crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
-crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
-crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
-crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
-crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
-crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2
-
-torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')
@@ -2,21 +2,14 @@ import os.path as osp
 import logging
 import time
 import argparse
-from collections import OrderedDict

 import os
-import options.options as option
+from utils import options as option
 import utils.util as util
-from data.util import bgr2ycbcr
-import models.archs.SwitchedResidualGenerator_arch as srg
-from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
-from switched_conv.switched_conv import compute_attention_specificity
 from data import create_dataset, create_dataloader
 from models import create_model
 from tqdm import tqdm
 import torch
-import models.networks as networks
-import shutil
 import torchvision

@@ -5,10 +5,8 @@ import math
 import argparse
 import random
 import torch
-import options.options as option
-from utils import util
+from utils import util, options as option
 from data import create_dataloader, create_dataset
-from time import time
 from tqdm import tqdm
 from skimage import io
@@ -1,9 +0,0 @@
-rm gen/*
-rm hr/*
-rm lr/*
-rm pix/*
-rm ref/*
-rm genlr/*
-rm genmr/*
-rm lr_precorrupt/*
-rm ref/*
@@ -1,21 +1,5 @@
-import os.path as osp
-import logging
-import time
-import argparse
-from collections import OrderedDict

-import os
-import options.options as option
-import utils.util as util
-from data.util import bgr2ycbcr
-import models.archs.SwitchedResidualGenerator_arch as srg
-from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
-from switched_conv.switched_conv import compute_attention_specificity
-from data import create_dataset, create_dataloader
-from models import create_model
-from tqdm import tqdm
 import torch
-import models.networks as networks

 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
@@ -39,7 +23,7 @@ class CheckpointFunction(torch.autograd.Function):
         input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)
         return (None, None) + input_grads

-from models.archs.arch_util import ConvGnSilu, UpconvBlock
+from models.archs.arch_util import ConvGnSilu
 import torch.nn as nn
 if __name__ == "__main__":
     model = nn.Sequential(ConvGnSilu(3, 64, 3, norm=False),
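CheckpointFunction above is the standard gradient-checkpointing trick: run the forward pass without storing activations, then recompute it under grad during backward. A minimal self-contained version of the idea, consistent with the two visible lines but not claiming to be the file's exact code:

import torch

class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():  # forward pass without saving activations
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Recompute the forward with grad enabled, then differentiate through it.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params,
                                          output_grads, allow_unused=True)
        # The first two forward args (run_function, length) get no gradient.
        return (None, None) + input_grads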
@@ -3,16 +3,14 @@ import math
 import argparse
 import random
 import logging
-import shutil
 from tqdm import tqdm

 import torch
 from data.data_sampler import DistIterSampler

-import options.options as option
-from utils import util
+from utils import util, options as option
 from data import create_dataloader, create_dataset
-from models import create_model
+from models.ExtensibleTrainer import ExtensibleTrainer
 from time import time

@@ -159,7 +157,7 @@ def main():
     assert train_loader is not None

     #### create model
-    model = create_model(opt)
+    model = ExtensibleTrainer(opt)

     #### resume training
     if resume_state:
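After this change the training loop constructs the trainer directly instead of going through the create_model factory, and the call site is otherwise untouched, which suggests ExtensibleTrainer keeps the BaseModel-style interface. A hedged sketch of the surrounding usage, with the loop details and method names assumed rather than taken from this diff:

from models.ExtensibleTrainer import ExtensibleTrainer

model = ExtensibleTrainer(opt)           # was: model = create_model(opt)
for step, train_data in enumerate(train_loader):
    model.feed_data(train_data)          # assumed BaseModel-style methods
    model.optimize_parameters(step)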
@@ -3,16 +3,14 @@ import math
 import argparse
 import random
 import logging
-import shutil
 from tqdm import tqdm

 import torch
 from data.data_sampler import DistIterSampler

-import options.options as option
-from utils import util
+from models.ExtensibleTrainer import ExtensibleTrainer
+from utils import util, options as option
 from data import create_dataloader, create_dataset
-from models import create_model
 from time import time

@@ -159,7 +157,7 @@ def main():
     assert train_loader is not None

     #### create model
-    model = create_model(opt)
+    model = ExtensibleTrainer(opt)

     #### resume training
     if resume_state:
@@ -1,7 +1,7 @@
 # Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive"
 # models which can be trained piecemeal.

-import options.options as option
+from utils import options as option
 from models import create_model
 import torch
 import os
@@ -1,7 +1,7 @@
 import argparse
 import functools
 import torch
-import options.options as option
+from utils import options as option
 from models.networks import define_G
