forked from mrq/DL-Art-School
Codebase cleanup
Removed a lot of legacy stuff I have no intent on using again. Plan is to shape this repo into something more extensible (get it? hah!)
This commit is contained in:
parent
e620fc05ba
commit
24792bdb4f
|
@ -1,4 +0,0 @@
|
|||
[style]
|
||||
BASED_ON_STYLE = pep8
|
||||
COLUMN_LIMIT = 100
|
||||
SPLIT_BEFORE_NAMED_ASSIGNS = false
|
|
@ -1,124 +0,0 @@
|
|||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
import lmdb
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
import data.util as util
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
|
||||
class DownsampleDataset(data.Dataset):
|
||||
"""
|
||||
Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a
|
||||
downsampled LQ image from the HQ image and feeds that as well.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(DownsampleDataset, self).__init__()
|
||||
self.opt = opt
|
||||
self.data_type = self.opt['data_type']
|
||||
self.paths_LQ, self.paths_GT = None, None
|
||||
self.sizes_LQ, self.sizes_GT = None, None
|
||||
self.LQ_env, self.GT_env = None, None # environments for lmdb
|
||||
self.doCrop = self.opt['doCrop']
|
||||
|
||||
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||||
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
||||
|
||||
self.data_sz_mismatch_ok = opt['mismatched_Data_OK']
|
||||
assert self.paths_GT, 'Error: GT path is empty.'
|
||||
assert self.paths_LQ, 'LQ is required for downsampling.'
|
||||
if not self.data_sz_mismatch_ok:
|
||||
assert len(self.paths_LQ) == len(
|
||||
self.paths_GT
|
||||
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
|
||||
len(self.paths_LQ), len(self.paths_GT))
|
||||
self.random_scale_list = [1]
|
||||
|
||||
def _init_lmdb(self):
|
||||
# https://github.com/chainer/chainermn/issues/129
|
||||
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
|
||||
self._init_lmdb()
|
||||
scale = self.opt['scale']
|
||||
GT_size = self.opt['target_size'] * scale
|
||||
|
||||
# get GT image
|
||||
GT_path = self.paths_GT[index % len(self.paths_GT)]
|
||||
resolution = [int(s) for s in self.sizes_GT[index].split('_')
|
||||
] if self.data_type == 'lmdb' else None
|
||||
img_GT = util.read_img(self.GT_env, GT_path, resolution)
|
||||
if self.opt['phase'] != 'train': # modcrop in the validation / test phase
|
||||
img_GT = util.modcrop(img_GT, scale)
|
||||
if self.opt['color']: # change color space if necessary
|
||||
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
|
||||
|
||||
# get LQ image
|
||||
LQ_path = self.paths_LQ[index % len(self.paths_LQ)]
|
||||
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
|
||||
] if self.data_type == 'lmdb' else None
|
||||
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
|
||||
|
||||
if self.opt['phase'] == 'train':
|
||||
H, W, _ = img_GT.shape
|
||||
assert H >= GT_size and W >= GT_size
|
||||
|
||||
H, W, C = img_LQ.shape
|
||||
LQ_size = GT_size // scale
|
||||
|
||||
if self.doCrop:
|
||||
# randomly crop
|
||||
rnd_h = random.randint(0, max(0, H - LQ_size))
|
||||
rnd_w = random.randint(0, max(0, W - LQ_size))
|
||||
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
|
||||
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
||||
else:
|
||||
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# augmentation - flip, rotate
|
||||
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
|
||||
self.opt['use_rot'])
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
if img_GT.shape[2] == 3:
|
||||
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
|
||||
img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# HQ needs to go to a PIL image to perform the compression-artifact transformation.
|
||||
H, W, _ = img_GT.shape
|
||||
img_GT = (img_GT * 255).astype(np.uint8)
|
||||
img_GT = Image.fromarray(img_GT)
|
||||
if self.opt['use_compression_artifacts']:
|
||||
qf = random.randrange(15, 100)
|
||||
corruption_buffer = BytesIO()
|
||||
img_GT.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
|
||||
corruption_buffer.seek(0)
|
||||
img_GT = Image.open(corruption_buffer)
|
||||
# Generate a downsampled image from HQ for feature and PIX losses.
|
||||
img_Downsampled = F.resize(img_GT, (int(H / scale), int(W / scale)))
|
||||
|
||||
img_GT = F.to_tensor(img_GT)
|
||||
img_Downsampled = F.to_tensor(img_Downsampled)
|
||||
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
|
||||
|
||||
# This may seem really messed up, but let me explain:
|
||||
# The goal is to re-use existing code as much as possible. SRGAN_model was coded to supersample, not downsample,
|
||||
# but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the HQ image and the
|
||||
# "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator which already know this,
|
||||
# we just need to trick its logic into following this rules.
|
||||
# Do this by setting LQ(which is the input into the models)=img_GT and GT(which is the expected outpuut)=img_LQ.
|
||||
# PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
|
||||
return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}
|
||||
|
||||
def __len__(self):
|
||||
return max(len(self.paths_GT), len(self.paths_LQ))
|
|
@ -1,239 +0,0 @@
|
|||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
import lmdb
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
import data.util as util
|
||||
from PIL import Image, ImageOps
|
||||
from io import BytesIO
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
|
||||
class LQGTDataset(data.Dataset):
|
||||
"""
|
||||
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
|
||||
If only GT images are provided, generate LQ images on-the-fly.
|
||||
"""
|
||||
def get_lq_path(self, i):
|
||||
which_lq = random.randint(0, len(self.paths_LQ)-1)
|
||||
return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]
|
||||
|
||||
def __init__(self, opt):
|
||||
super(LQGTDataset, self).__init__()
|
||||
self.opt = opt
|
||||
self.data_type = self.opt['data_type']
|
||||
self.paths_LQ, self.paths_GT = None, None
|
||||
self.sizes_LQ, self.sizes_GT = None, None
|
||||
self.paths_PIX, self.sizes_PIX = None, None
|
||||
self.paths_GAN, self.sizes_GAN = None, None
|
||||
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs
|
||||
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
|
||||
|
||||
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
|
||||
if 'dataroot_LQ' in opt.keys():
|
||||
self.paths_LQ = []
|
||||
if isinstance(opt['dataroot_LQ'], list):
|
||||
# Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
|
||||
# we want the model to learn them all.
|
||||
for dr_lq in opt['dataroot_LQ']:
|
||||
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
|
||||
self.paths_LQ.append(lq_path)
|
||||
else:
|
||||
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
||||
self.paths_LQ.append(lq_path)
|
||||
self.doCrop = opt['doCrop']
|
||||
if 'dataroot_PIX' in opt.keys():
|
||||
self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
|
||||
# dataroot_GAN is an alternative source of LR images specifically for use in computing the GAN loss, where
|
||||
# LR and HR do not need to be paired.
|
||||
if 'dataroot_GAN' in opt.keys():
|
||||
self.paths_GAN, self.sizes_GAN = util.get_image_paths(self.data_type, opt['dataroot_GAN'])
|
||||
print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,))
|
||||
|
||||
assert self.paths_GT, 'Error: GT path is empty.'
|
||||
self.random_scale_list = [1]
|
||||
|
||||
def _init_lmdb(self):
|
||||
# https://github.com/chainer/chainermn/issues/129
|
||||
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
if 'dataroot_PIX' in self.opt.keys():
|
||||
self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
|
||||
def motion_blur(self, image, size, angle):
|
||||
k = np.zeros((size, size), dtype=np.float32)
|
||||
k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
|
||||
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
|
||||
k = k * (1.0 / np.sum(k))
|
||||
return cv2.filter2D(image, -1, k)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
|
||||
self._init_lmdb()
|
||||
GT_path, LQ_path = None, None
|
||||
scale = self.opt['scale']
|
||||
GT_size = self.opt['target_size']
|
||||
|
||||
# get GT image
|
||||
GT_path = self.paths_GT[index % len(self.paths_GT)]
|
||||
resolution = [int(s) for s in self.sizes_GT[index].split('_')
|
||||
] if self.data_type == 'lmdb' else None
|
||||
img_GT = util.read_img(self.GT_env, GT_path, resolution)
|
||||
if self.opt['phase'] != 'train': # modcrop in the validation / test phase
|
||||
img_GT = util.modcrop(img_GT, scale)
|
||||
if self.opt['color']: # change color space if necessary
|
||||
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
|
||||
|
||||
# get the pix image
|
||||
if self.paths_PIX is not None:
|
||||
PIX_path = self.paths_PIX[index % len(self.paths_PIX)]
|
||||
img_PIX = util.read_img(self.PIX_env, PIX_path, resolution)
|
||||
if self.opt['color']: # change color space if necessary
|
||||
img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0]
|
||||
else:
|
||||
img_PIX = img_GT
|
||||
|
||||
# get LQ image
|
||||
if self.paths_LQ:
|
||||
LQ_path = self.get_lq_path(index)
|
||||
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
|
||||
] if self.data_type == 'lmdb' else None
|
||||
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
|
||||
else: # down-sampling on-the-fly
|
||||
# randomly scale during training
|
||||
if self.opt['phase'] == 'train':
|
||||
random_scale = random.choice(self.random_scale_list)
|
||||
H_s, W_s, _ = img_GT.shape
|
||||
|
||||
def _mod(n, random_scale, scale, thres):
|
||||
rlt = int(n * random_scale)
|
||||
rlt = (rlt // scale) * scale
|
||||
return thres if rlt < thres else rlt
|
||||
|
||||
H_s = _mod(H_s, random_scale, scale, GT_size)
|
||||
W_s = _mod(W_s, random_scale, scale, GT_size)
|
||||
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
|
||||
if img_GT.ndim == 2:
|
||||
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
H, W, _ = img_GT.shape
|
||||
|
||||
# using matlab imresize
|
||||
if scale == 1:
|
||||
img_LQ = img_GT
|
||||
else:
|
||||
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
|
||||
if img_LQ.ndim == 2:
|
||||
img_LQ = np.expand_dims(img_LQ, axis=2)
|
||||
|
||||
img_GAN = None
|
||||
if self.paths_GAN:
|
||||
GAN_path = self.paths_GAN[index % self.sizes_GAN]
|
||||
img_GAN = util.read_img(self.LQ_env, GAN_path)
|
||||
|
||||
# Enforce force_resize constraints.
|
||||
h, w, _ = img_LQ.shape
|
||||
if h % self.force_multiple != 0 or w % self.force_multiple != 0:
|
||||
h, w = (h - h % self.force_multiple), (w - w % self.force_multiple)
|
||||
img_LQ = img_LQ[:h, :w, :]
|
||||
h *= scale
|
||||
w *= scale
|
||||
img_GT = img_GT[:h, :w, :]
|
||||
img_PIX = img_PIX[:h, :w, :]
|
||||
|
||||
if self.opt['phase'] == 'train':
|
||||
H, W, _ = img_GT.shape
|
||||
assert H >= GT_size and W >= GT_size
|
||||
|
||||
H, W, C = img_LQ.shape
|
||||
LQ_size = GT_size // scale
|
||||
|
||||
if self.doCrop:
|
||||
# randomly crop
|
||||
rnd_h = random.randint(0, max(0, H - LQ_size))
|
||||
rnd_w = random.randint(0, max(0, W - LQ_size))
|
||||
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||
if img_GAN is not None:
|
||||
img_GAN = img_GAN[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
|
||||
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
||||
img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
||||
else:
|
||||
if img_LQ.shape[0] != LQ_size:
|
||||
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
if img_GAN is not None:
|
||||
img_GAN = cv2.resize(img_GAN, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
if img_GT.shape[0] != GT_size:
|
||||
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
if img_PIX.shape[0] != GT_size:
|
||||
img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if 'doResizeLoss' in self.opt.keys() and self.opt['doResizeLoss']:
|
||||
r = random.randrange(0, 10)
|
||||
if r > 5:
|
||||
img_LQ = cv2.resize(img_LQ, (int(LQ_size/2), int(LQ_size/2)), interpolation=cv2.INTER_LINEAR)
|
||||
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# augmentation - flip, rotate
|
||||
img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'],
|
||||
self.opt['use_rot'])
|
||||
|
||||
if self.opt['use_blurring']:
|
||||
# Pick randomly between gaussian, motion, or no blur.
|
||||
blur_det = random.randint(0, 100)
|
||||
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
|
||||
if blur_det < 40:
|
||||
blur_sig = int(random.randrange(0, blur_magnitude))
|
||||
img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
|
||||
elif blur_det < 70:
|
||||
img_LQ = self.motion_blur(img_LQ, random.randrange(1, blur_magnitude * 3), random.randint(0, 360))
|
||||
|
||||
|
||||
if self.opt['color']: # change color space if necessary
|
||||
img_LQ = util.channel_convert(C, self.opt['color'],
|
||||
[img_LQ])[0] # TODO during val no definition
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
if img_GT.shape[2] == 3:
|
||||
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
|
||||
img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
|
||||
if img_GAN is not None:
|
||||
img_GAN = cv2.cvtColor(img_GAN, cv2.COLOR_BGR2RGB)
|
||||
img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
|
||||
img_LQ = (img_LQ * 255).astype(np.uint8)
|
||||
img_LQ = Image.fromarray(img_LQ)
|
||||
if self.opt['use_compression_artifacts'] and random.random() > .25:
|
||||
qf = random.randrange(10, 70)
|
||||
corruption_buffer = BytesIO()
|
||||
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
|
||||
corruption_buffer.seek(0)
|
||||
img_LQ = Image.open(corruption_buffer)
|
||||
|
||||
if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
|
||||
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
|
||||
|
||||
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||
img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
|
||||
img_LQ = F.to_tensor(img_LQ)
|
||||
if img_GAN is not None:
|
||||
img_GAN = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GAN, (2, 0, 1)))).float()
|
||||
|
||||
if 'lq_noise' in self.opt.keys():
|
||||
lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
|
||||
img_LQ += lq_noise
|
||||
|
||||
if LQ_path is None:
|
||||
LQ_path = GT_path
|
||||
d = {'LQ': img_LQ, 'GT': img_GT, 'ref': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
|
||||
if img_GAN is not None:
|
||||
d['GAN'] = img_GAN
|
||||
return d
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths_GT)
|
|
@ -1,71 +0,0 @@
|
|||
import numpy as np
|
||||
import lmdb
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
import data.util as util
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image
|
||||
import os.path as osp
|
||||
import cv2
|
||||
|
||||
|
||||
class LQDataset(data.Dataset):
|
||||
'''Read LQ images only in the test phase.'''
|
||||
|
||||
def __init__(self, opt):
|
||||
super(LQDataset, self).__init__()
|
||||
self.opt = opt
|
||||
self.data_type = self.opt['data_type']
|
||||
if 'start_at' in self.opt.keys():
|
||||
self.start_at = self.opt['start_at']
|
||||
else:
|
||||
self.start_at = 0
|
||||
self.vertical_splits = self.opt['vertical_splits']
|
||||
self.paths_LQ, self.paths_GT = None, None
|
||||
self.LQ_env = None # environment for lmdb
|
||||
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
|
||||
|
||||
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
||||
self.paths_LQ = self.paths_LQ[self.start_at:]
|
||||
assert self.paths_LQ, 'Error: LQ paths are empty.'
|
||||
|
||||
def _init_lmdb(self):
|
||||
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.data_type == 'lmdb' and self.LQ_env is None:
|
||||
self._init_lmdb()
|
||||
if self.vertical_splits > 0:
|
||||
actual_index = int(index / self.vertical_splits)
|
||||
else:
|
||||
actual_index = index
|
||||
|
||||
# get LQ image
|
||||
LQ_path = self.paths_LQ[actual_index]
|
||||
img_LQ = Image.open(LQ_path)
|
||||
if self.vertical_splits > 0:
|
||||
w, h = img_LQ.size
|
||||
split_index = (index % self.vertical_splits)
|
||||
w_per_split = int(w / self.vertical_splits)
|
||||
left = w_per_split * split_index
|
||||
img_LQ = F.crop(img_LQ, 0, left, h, w_per_split)
|
||||
|
||||
# Enforce force_resize constraints.
|
||||
h, w = img_LQ.size
|
||||
if h % self.force_multiple != 0 or w % self.force_multiple != 0:
|
||||
h, w = (w - w % self.force_multiple), (h - h % self.force_multiple)
|
||||
img_LQ = img_LQ.resize((w, h))
|
||||
|
||||
img_LQ = F.to_tensor(img_LQ)
|
||||
|
||||
img_name = osp.splitext(osp.basename(LQ_path))[0]
|
||||
LQ_path = LQ_path.replace(img_name, img_name + "_%i" % (index % self.vertical_splits))
|
||||
|
||||
return {'LQ': img_LQ, 'LQ_path': LQ_path}
|
||||
|
||||
def __len__(self):
|
||||
if self.vertical_splits > 0:
|
||||
return len(self.paths_LQ) * self.vertical_splits
|
||||
else:
|
||||
return len(self.paths_LQ)
|
|
@ -29,14 +29,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
|||
def create_dataset(dataset_opt):
|
||||
mode = dataset_opt['mode']
|
||||
# datasets for image restoration
|
||||
if mode == 'LQ':
|
||||
from data.LQ_dataset import LQDataset as D
|
||||
elif mode == 'LQGT':
|
||||
from data.LQGT_dataset import LQGTDataset as D
|
||||
# datasets for image corruption
|
||||
elif mode == 'downsample':
|
||||
from data.Downsample_dataset import DownsampleDataset as D
|
||||
elif mode == 'fullimage':
|
||||
if mode == 'fullimage':
|
||||
from data.full_image_dataset import FullImageDataset as D
|
||||
elif mode == 'single_image_extensible':
|
||||
from data.single_image_dataset import SingleImageDataset as D
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,171 +0,0 @@
|
|||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
import models.networks as networks
|
||||
import models.lr_scheduler as lr_scheduler
|
||||
from .base_model import BaseModel
|
||||
from models.loss import CharbonnierLoss
|
||||
from apex import amp
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
||||
class SRModel(BaseModel):
|
||||
def __init__(self, opt):
|
||||
super(SRModel, self).__init__(opt)
|
||||
|
||||
if opt['dist']:
|
||||
self.rank = torch.distributed.get_rank()
|
||||
else:
|
||||
self.rank = -1 # non dist training
|
||||
train_opt = opt['train']
|
||||
|
||||
# define network and load pretrained models
|
||||
self.netG = amp.initialize(networks.define_G(opt).to(self.device), opt_level=self.amp_level)
|
||||
if opt['dist']:
|
||||
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
||||
elif opt['gpu_ids'] is not None:
|
||||
self.netG = DataParallel(self.netG)
|
||||
# print network
|
||||
self.print_network()
|
||||
self.load()
|
||||
|
||||
if self.is_train:
|
||||
self.netG.train()
|
||||
|
||||
# loss
|
||||
loss_type = train_opt['pixel_criterion']
|
||||
if loss_type == 'l1':
|
||||
self.cri_pix = nn.L1Loss().to(self.device)
|
||||
elif loss_type == 'l2':
|
||||
self.cri_pix = nn.MSELoss().to(self.device)
|
||||
elif loss_type == 'cb':
|
||||
self.cri_pix = CharbonnierLoss().to(self.device)
|
||||
else:
|
||||
raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
|
||||
self.l_pix_w = train_opt['pixel_weight']
|
||||
|
||||
# optimizers
|
||||
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
|
||||
optim_params = []
|
||||
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
|
||||
if v.requires_grad:
|
||||
optim_params.append(v)
|
||||
else:
|
||||
if self.rank <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
|
||||
weight_decay=wd_G,
|
||||
betas=(train_opt['beta1'], train_opt['beta2']))
|
||||
self.optimizers.append(self.optimizer_G)
|
||||
|
||||
# schedulers
|
||||
if train_opt['lr_scheme'] == 'MultiStepLR':
|
||||
for optimizer in self.optimizers:
|
||||
self.schedulers.append(
|
||||
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
|
||||
restarts=train_opt['restarts'],
|
||||
weights=train_opt['restart_weights'],
|
||||
gamma=train_opt['lr_gamma'],
|
||||
clear_state=train_opt['clear_state']))
|
||||
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
|
||||
for optimizer in self.optimizers:
|
||||
self.schedulers.append(
|
||||
lr_scheduler.CosineAnnealingLR_Restart(
|
||||
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
|
||||
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
|
||||
else:
|
||||
raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
|
||||
|
||||
self.log_dict = OrderedDict()
|
||||
|
||||
def feed_data(self, data, need_GT=True):
|
||||
self.var_L = data['LQ'].to(self.device) # LQ
|
||||
if need_GT:
|
||||
self.real_H = data['GT'].to(self.device) # GT
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
self.optimizer_G.zero_grad()
|
||||
self.fake_H = self.netG(self.var_L)
|
||||
l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
|
||||
l_pix.backward()
|
||||
self.optimizer_G.step()
|
||||
|
||||
# set log
|
||||
self.log_dict['l_pix'] = l_pix.item()
|
||||
|
||||
def test(self):
|
||||
self.netG.eval()
|
||||
with torch.no_grad():
|
||||
self.fake_H = self.netG(self.var_L)
|
||||
self.netG.train()
|
||||
|
||||
def test_x8(self):
|
||||
# from https://github.com/thstkdgus35/EDSR-PyTorch
|
||||
self.netG.eval()
|
||||
|
||||
def _transform(v, op):
|
||||
# if self.precision != 'single': v = v.float()
|
||||
v2np = v.data.cpu().numpy()
|
||||
if op == 'v':
|
||||
tfnp = v2np[:, :, :, ::-1].copy()
|
||||
elif op == 'h':
|
||||
tfnp = v2np[:, :, ::-1, :].copy()
|
||||
elif op == 't':
|
||||
tfnp = v2np.transpose((0, 1, 3, 2)).copy()
|
||||
|
||||
ret = torch.Tensor(tfnp).to(self.device)
|
||||
# if self.precision == 'half': ret = ret.half()
|
||||
|
||||
return ret
|
||||
|
||||
lr_list = [self.var_L]
|
||||
for tf in 'v', 'h', 't':
|
||||
lr_list.extend([_transform(t, tf) for t in lr_list])
|
||||
with torch.no_grad():
|
||||
sr_list = [self.netG(aug) for aug in lr_list]
|
||||
for i in range(len(sr_list)):
|
||||
if i > 3:
|
||||
sr_list[i] = _transform(sr_list[i], 't')
|
||||
if i % 4 > 1:
|
||||
sr_list[i] = _transform(sr_list[i], 'h')
|
||||
if (i % 4) % 2 == 1:
|
||||
sr_list[i] = _transform(sr_list[i], 'v')
|
||||
|
||||
output_cat = torch.cat(sr_list, dim=0)
|
||||
self.fake_H = output_cat.mean(dim=0, keepdim=True)
|
||||
self.netG.train()
|
||||
|
||||
def get_current_log(self):
|
||||
return self.log_dict
|
||||
|
||||
def get_current_visuals(self, need_GT=True):
|
||||
out_dict = OrderedDict()
|
||||
out_dict['LQ'] = self.var_L.detach().float().cpu()
|
||||
out_dict['rlt'] = self.fake_H.detach().float().cpu()
|
||||
if need_GT:
|
||||
out_dict['GT'] = self.real_H.detach().float().cpu()
|
||||
return out_dict
|
||||
|
||||
def print_network(self):
|
||||
s, n = self.get_network_description(self.netG)
|
||||
if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
|
||||
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
|
||||
self.netG.module.__class__.__name__)
|
||||
else:
|
||||
net_struc_str = '{}'.format(self.netG.__class__.__name__)
|
||||
if self.rank <= 0:
|
||||
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
|
||||
logger.info(s)
|
||||
|
||||
def load(self):
|
||||
load_path_G = self.opt['path']['pretrain_model_G']
|
||||
if load_path_G is not None:
|
||||
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
|
||||
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
|
||||
|
||||
def save(self, iter_label):
|
||||
self.save_network(self.netG, 'G', iter_label)
|
|
@ -1,22 +0,0 @@
|
|||
import logging
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
||||
def create_model(opt):
|
||||
model = opt['model']
|
||||
# image restoration
|
||||
if model == 'sr': # PSNR-oriented super resolution
|
||||
from .SR_model import SRModel as M
|
||||
elif model == 'srgan' or model == 'corruptgan' or model == 'spsrgan':
|
||||
from .SRGAN_model import SRGANModel as M
|
||||
elif model == 'feat':
|
||||
from .feature_model import FeatureModel as M
|
||||
elif model == 'spsr':
|
||||
from .SPSR_model import SPSRModel as M
|
||||
elif model == 'extensibletrainer':
|
||||
from .ExtensibleTrainer import ExtensibleTrainer as M
|
||||
else:
|
||||
raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
|
||||
m = M(opt)
|
||||
logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
|
||||
return m
|
|
@ -1,39 +0,0 @@
|
|||
from torch.autograd import Function, Variable
|
||||
from torch.nn.modules.module import Module
|
||||
import channelnorm_cuda
|
||||
|
||||
class ChannelNormFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input1, norm_deg=2):
|
||||
assert input1.is_contiguous()
|
||||
b, _, h, w = input1.size()
|
||||
output = input1.new(b, 1, h, w).zero_()
|
||||
|
||||
channelnorm_cuda.forward(input1, output, norm_deg)
|
||||
ctx.save_for_backward(input1, output)
|
||||
ctx.norm_deg = norm_deg
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input1, output = ctx.saved_tensors
|
||||
|
||||
grad_input1 = Variable(input1.new(input1.size()).zero_())
|
||||
|
||||
channelnorm_cuda.backward(input1, output, grad_output.data,
|
||||
grad_input1.data, ctx.norm_deg)
|
||||
|
||||
return grad_input1, None
|
||||
|
||||
|
||||
class ChannelNorm(Module):
|
||||
|
||||
def __init__(self, norm_deg=2):
|
||||
super(ChannelNorm, self).__init__()
|
||||
self.norm_deg = norm_deg
|
||||
|
||||
def forward(self, input1):
|
||||
return ChannelNormFunction.apply(input1, self.norm_deg)
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
#include <torch/torch.h>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include "channelnorm_kernel.cuh"
|
||||
|
||||
int channelnorm_cuda_forward(
|
||||
at::Tensor& input1,
|
||||
at::Tensor& output,
|
||||
int norm_deg) {
|
||||
|
||||
channelnorm_kernel_forward(input1, output, norm_deg);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
int channelnorm_cuda_backward(
|
||||
at::Tensor& input1,
|
||||
at::Tensor& output,
|
||||
at::Tensor& gradOutput,
|
||||
at::Tensor& gradInput1,
|
||||
int norm_deg) {
|
||||
|
||||
channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg);
|
||||
return 1;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)");
|
||||
m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)");
|
||||
}
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
void channelnorm_kernel_forward(
|
||||
at::Tensor& input1,
|
||||
at::Tensor& output,
|
||||
int norm_deg);
|
||||
|
||||
|
||||
void channelnorm_kernel_backward(
|
||||
at::Tensor& input1,
|
||||
at::Tensor& output,
|
||||
at::Tensor& gradOutput,
|
||||
at::Tensor& gradInput1,
|
||||
int norm_deg);
|
|
@ -1,28 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
import os
|
||||
import torch
|
||||
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
cxx_args = ['-std=c++11']
|
||||
|
||||
nvcc_args = [
|
||||
'-gencode', 'arch=compute_52,code=sm_52',
|
||||
'-gencode', 'arch=compute_60,code=sm_60',
|
||||
'-gencode', 'arch=compute_61,code=sm_61',
|
||||
'-gencode', 'arch=compute_70,code=sm_70',
|
||||
'-gencode', 'arch=compute_70,code=compute_70'
|
||||
]
|
||||
|
||||
setup(
|
||||
name='channelnorm_cuda',
|
||||
ext_modules=[
|
||||
CUDAExtension('channelnorm_cuda', [
|
||||
'channelnorm_cuda.cc',
|
||||
'channelnorm_kernel.cu'
|
||||
], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
|
@ -1,61 +0,0 @@
|
|||
import torch
|
||||
from torch.nn.modules.module import Module
|
||||
from torch.autograd import Function
|
||||
import correlation_cuda
|
||||
|
||||
class CorrelationFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input1, input2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
|
||||
ctx.save_for_backward(input1, input2)
|
||||
|
||||
ctx.pad_size = pad_size
|
||||
ctx.kernel_size = kernel_size
|
||||
ctx.max_displacement = max_displacement
|
||||
ctx.stride1 = stride1
|
||||
ctx.stride2 = stride2
|
||||
ctx.corr_multiply = corr_multiply
|
||||
|
||||
with torch.cuda.device_of(input1):
|
||||
rbot1 = input1.new()
|
||||
rbot2 = input2.new()
|
||||
output = input1.new()
|
||||
|
||||
correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
|
||||
ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input1, input2 = ctx.saved_tensors
|
||||
|
||||
with torch.cuda.device_of(input1):
|
||||
rbot1 = input1.new()
|
||||
rbot2 = input2.new()
|
||||
|
||||
grad_input1 = input1.new()
|
||||
grad_input2 = input2.new()
|
||||
|
||||
correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
|
||||
ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)
|
||||
|
||||
return grad_input1, grad_input2, None, None, None, None, None, None
|
||||
|
||||
|
||||
class Correlation(Module):
|
||||
def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
|
||||
super(Correlation, self).__init__()
|
||||
self.pad_size = pad_size
|
||||
self.kernel_size = kernel_size
|
||||
self.max_displacement = max_displacement
|
||||
self.stride1 = stride1
|
||||
self.stride2 = stride2
|
||||
self.corr_multiply = corr_multiply
|
||||
|
||||
def forward(self, input1, input2):
|
||||
|
||||
result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)
|
||||
|
||||
return result
|
||||
|
|
@ -1,173 +0,0 @@
|
|||
#include <torch/torch.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <stdio.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "correlation_cuda_kernel.cuh"
|
||||
|
||||
int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
|
||||
int pad_size,
|
||||
int kernel_size,
|
||||
int max_displacement,
|
||||
int stride1,
|
||||
int stride2,
|
||||
int corr_type_multiply)
|
||||
{
|
||||
|
||||
int batchSize = input1.size(0);
|
||||
|
||||
int nInputChannels = input1.size(1);
|
||||
int inputHeight = input1.size(2);
|
||||
int inputWidth = input1.size(3);
|
||||
|
||||
int kernel_radius = (kernel_size - 1) / 2;
|
||||
int border_radius = kernel_radius + max_displacement;
|
||||
|
||||
int paddedInputHeight = inputHeight + 2 * pad_size;
|
||||
int paddedInputWidth = inputWidth + 2 * pad_size;
|
||||
|
||||
int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
|
||||
|
||||
int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
|
||||
int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
|
||||
|
||||
rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
|
||||
rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
|
||||
output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});
|
||||
|
||||
rInput1.fill_(0);
|
||||
rInput2.fill_(0);
|
||||
output.fill_(0);
|
||||
|
||||
int success = correlation_forward_cuda_kernel(
|
||||
output,
|
||||
output.size(0),
|
||||
output.size(1),
|
||||
output.size(2),
|
||||
output.size(3),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
output.stride(2),
|
||||
output.stride(3),
|
||||
input1,
|
||||
input1.size(1),
|
||||
input1.size(2),
|
||||
input1.size(3),
|
||||
input1.stride(0),
|
||||
input1.stride(1),
|
||||
input1.stride(2),
|
||||
input1.stride(3),
|
||||
input2,
|
||||
input2.size(1),
|
||||
input2.stride(0),
|
||||
input2.stride(1),
|
||||
input2.stride(2),
|
||||
input2.stride(3),
|
||||
rInput1,
|
||||
rInput2,
|
||||
pad_size,
|
||||
kernel_size,
|
||||
max_displacement,
|
||||
stride1,
|
||||
stride2,
|
||||
corr_type_multiply,
|
||||
at::cuda::getCurrentCUDAStream()
|
||||
//at::globalContext().getCurrentCUDAStream()
|
||||
);
|
||||
|
||||
//check for errors
|
||||
if (!success) {
|
||||
AT_ERROR("CUDA call failed");
|
||||
}
|
||||
|
||||
return 1;
|
||||
|
||||
}
|
||||
|
||||
int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
|
||||
at::Tensor& gradInput1, at::Tensor& gradInput2,
|
||||
int pad_size,
|
||||
int kernel_size,
|
||||
int max_displacement,
|
||||
int stride1,
|
||||
int stride2,
|
||||
int corr_type_multiply)
|
||||
{
|
||||
|
||||
int batchSize = input1.size(0);
|
||||
int nInputChannels = input1.size(1);
|
||||
int paddedInputHeight = input1.size(2)+ 2 * pad_size;
|
||||
int paddedInputWidth = input1.size(3)+ 2 * pad_size;
|
||||
|
||||
int height = input1.size(2);
|
||||
int width = input1.size(3);
|
||||
|
||||
rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
|
||||
rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
|
||||
gradInput1.resize_({batchSize, nInputChannels, height, width});
|
||||
gradInput2.resize_({batchSize, nInputChannels, height, width});
|
||||
|
||||
rInput1.fill_(0);
|
||||
rInput2.fill_(0);
|
||||
gradInput1.fill_(0);
|
||||
gradInput2.fill_(0);
|
||||
|
||||
int success = correlation_backward_cuda_kernel(gradOutput,
|
||||
gradOutput.size(0),
|
||||
gradOutput.size(1),
|
||||
gradOutput.size(2),
|
||||
gradOutput.size(3),
|
||||
gradOutput.stride(0),
|
||||
gradOutput.stride(1),
|
||||
gradOutput.stride(2),
|
||||
gradOutput.stride(3),
|
||||
input1,
|
||||
input1.size(1),
|
||||
input1.size(2),
|
||||
input1.size(3),
|
||||
input1.stride(0),
|
||||
input1.stride(1),
|
||||
input1.stride(2),
|
||||
input1.stride(3),
|
||||
input2,
|
||||
input2.stride(0),
|
||||
input2.stride(1),
|
||||
input2.stride(2),
|
||||
input2.stride(3),
|
||||
gradInput1,
|
||||
gradInput1.stride(0),
|
||||
gradInput1.stride(1),
|
||||
gradInput1.stride(2),
|
||||
gradInput1.stride(3),
|
||||
gradInput2,
|
||||
gradInput2.size(1),
|
||||
gradInput2.stride(0),
|
||||
gradInput2.stride(1),
|
||||
gradInput2.stride(2),
|
||||
gradInput2.stride(3),
|
||||
rInput1,
|
||||
rInput2,
|
||||
pad_size,
|
||||
kernel_size,
|
||||
max_displacement,
|
||||
stride1,
|
||||
stride2,
|
||||
corr_type_multiply,
|
||||
at::cuda::getCurrentCUDAStream()
|
||||
//at::globalContext().getCurrentCUDAStream()
|
||||
);
|
||||
|
||||
if (!success) {
|
||||
AT_ERROR("CUDA call failed");
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
|
||||
m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
|
||||
}
|
||||
|
|
@ -1,91 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Context.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
int correlation_forward_cuda_kernel(at::Tensor& output,
|
||||
int ob,
|
||||
int oc,
|
||||
int oh,
|
||||
int ow,
|
||||
int osb,
|
||||
int osc,
|
||||
int osh,
|
||||
int osw,
|
||||
|
||||
at::Tensor& input1,
|
||||
int ic,
|
||||
int ih,
|
||||
int iw,
|
||||
int isb,
|
||||
int isc,
|
||||
int ish,
|
||||
int isw,
|
||||
|
||||
at::Tensor& input2,
|
||||
int gc,
|
||||
int gsb,
|
||||
int gsc,
|
||||
int gsh,
|
||||
int gsw,
|
||||
|
||||
at::Tensor& rInput1,
|
||||
at::Tensor& rInput2,
|
||||
int pad_size,
|
||||
int kernel_size,
|
||||
int max_displacement,
|
||||
int stride1,
|
||||
int stride2,
|
||||
int corr_type_multiply,
|
||||
cudaStream_t stream);
|
||||
|
||||
|
||||
int correlation_backward_cuda_kernel(
|
||||
at::Tensor& gradOutput,
|
||||
int gob,
|
||||
int goc,
|
||||
int goh,
|
||||
int gow,
|
||||
int gosb,
|
||||
int gosc,
|
||||
int gosh,
|
||||
int gosw,
|
||||
|
||||
at::Tensor& input1,
|
||||
int ic,
|
||||
int ih,
|
||||
int iw,
|
||||
int isb,
|
||||
int isc,
|
||||
int ish,
|
||||
int isw,
|
||||
|
||||
at::Tensor& input2,
|
||||
int gsb,
|
||||
int gsc,
|
||||
int gsh,
|
||||
int gsw,
|
||||
|
||||
at::Tensor& gradInput1,
|
||||
int gisb,
|
||||
int gisc,
|
||||
int gish,
|
||||
int gisw,
|
||||
|
||||
at::Tensor& gradInput2,
|
||||
int ggc,
|
||||
int ggsb,
|
||||
int ggsc,
|
||||
int ggsh,
|
||||
int ggsw,
|
||||
|
||||
at::Tensor& rInput1,
|
||||
at::Tensor& rInput2,
|
||||
int pad_size,
|
||||
int kernel_size,
|
||||
int max_displacement,
|
||||
int stride1,
|
||||
int stride2,
|
||||
int corr_type_multiply,
|
||||
cudaStream_t stream);
|
|
@ -1,29 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
import os
|
||||
import torch
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
cxx_args = ['-std=c++11']
|
||||
|
||||
nvcc_args = [
|
||||
'-gencode', 'arch=compute_50,code=sm_50',
|
||||
'-gencode', 'arch=compute_52,code=sm_52',
|
||||
'-gencode', 'arch=compute_60,code=sm_60',
|
||||
'-gencode', 'arch=compute_61,code=sm_61',
|
||||
'-gencode', 'arch=compute_70,code=sm_70',
|
||||
'-gencode', 'arch=compute_70,code=compute_70'
|
||||
]
|
||||
|
||||
setup(
|
||||
name='correlation_cuda',
|
||||
ext_modules=[
|
||||
CUDAExtension('correlation_cuda', [
|
||||
'correlation_cuda.cc',
|
||||
'correlation_cuda_kernel.cu'
|
||||
], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
|
@ -1,49 +0,0 @@
|
|||
from torch.nn.modules.module import Module
|
||||
from torch.autograd import Function, Variable
|
||||
import resample2d_cuda
|
||||
|
||||
class Resample2dFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input1, input2, kernel_size=1, bilinear= True):
|
||||
assert input1.is_contiguous()
|
||||
assert input2.is_contiguous()
|
||||
|
||||
ctx.save_for_backward(input1, input2)
|
||||
ctx.kernel_size = kernel_size
|
||||
ctx.bilinear = bilinear
|
||||
|
||||
_, d, _, _ = input1.size()
|
||||
b, _, h, w = input2.size()
|
||||
output = input1.new(b, d, h, w).zero_()
|
||||
|
||||
resample2d_cuda.forward(input1, input2, output, kernel_size, bilinear)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_output = grad_output.contiguous()
|
||||
assert grad_output.is_contiguous()
|
||||
|
||||
input1, input2 = ctx.saved_tensors
|
||||
|
||||
grad_input1 = Variable(input1.new(input1.size()).zero_())
|
||||
grad_input2 = Variable(input1.new(input2.size()).zero_())
|
||||
|
||||
resample2d_cuda.backward(input1, input2, grad_output.data,
|
||||
grad_input1.data, grad_input2.data,
|
||||
ctx.kernel_size, ctx.bilinear)
|
||||
|
||||
return grad_input1, grad_input2, None, None
|
||||
|
||||
class Resample2d(Module):
|
||||
|
||||
def __init__(self, kernel_size=1, bilinear = True):
|
||||
super(Resample2d, self).__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.bilinear = bilinear
|
||||
|
||||
def forward(self, input1, input2):
|
||||
input1_c = input1.contiguous()
|
||||
return Resample2dFunction.apply(input1_c, input2, self.kernel_size, self.bilinear)
|
|
@ -1,32 +0,0 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include "resample2d_kernel.cuh"
|
||||
|
||||
int resample2d_cuda_forward(
|
||||
at::Tensor& input1,
|
||||
at::Tensor& input2,
|
||||
at::Tensor& output,
|
||||
int kernel_size, bool bilinear) {
|
||||
resample2d_kernel_forward(input1, input2, output, kernel_size, bilinear);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int resample2d_cuda_backward(
|
||||
at::Tensor& input1,
|
||||
at::Tensor& input2,
|
||||
at::Tensor& gradOutput,
|
||||
at::Tensor& gradInput1,
|
||||
at::Tensor& gradInput2,
|
||||
int kernel_size, bool bilinear) {
|
||||
resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size, bilinear);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)");
|
||||
m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)");
|
||||
}
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
void resample2d_kernel_forward(
|
||||
at::Tensor& input1,
|
||||
at::Tensor& input2,
|
||||
at::Tensor& output,
|
||||
int kernel_size,
|
||||
bool bilinear);
|
||||
|
||||
void resample2d_kernel_backward(
|
||||
at::Tensor& input1,
|
||||
at::Tensor& input2,
|
||||
at::Tensor& gradOutput,
|
||||
at::Tensor& gradInput1,
|
||||
at::Tensor& gradInput2,
|
||||
int kernel_size,
|
||||
bool bilinear);
|
|
@ -1,29 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
import os
|
||||
import torch
|
||||
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
cxx_args = ['-std=c++11']
|
||||
|
||||
nvcc_args = [
|
||||
'-gencode', 'arch=compute_50,code=sm_50',
|
||||
'-gencode', 'arch=compute_52,code=sm_52',
|
||||
'-gencode', 'arch=compute_60,code=sm_60',
|
||||
'-gencode', 'arch=compute_61,code=sm_61',
|
||||
'-gencode', 'arch=compute_70,code=sm_70',
|
||||
'-gencode', 'arch=compute_70,code=compute_70'
|
||||
]
|
||||
|
||||
setup(
|
||||
name='resample2d_cuda',
|
||||
ext_modules=[
|
||||
CUDAExtension('resample2d_cuda', [
|
||||
'resample2d_cuda.cc',
|
||||
'resample2d_kernel.cu'
|
||||
], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
|
@ -49,7 +49,7 @@ class GANLoss(nn.Module):
|
|||
|
||||
|
||||
# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL
|
||||
# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py
|
||||
# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see scripts/compute_fdpl_perceptual_weights.py
|
||||
# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs).
|
||||
class FDPLLoss(nn.Module):
|
||||
"""
|
||||
|
|
|
@ -1,71 +0,0 @@
|
|||
# Author Masashi Kimura (Convergence Lab.)
|
||||
import torch
|
||||
from torch import optim
|
||||
import math
|
||||
|
||||
class NovoGrad(optim.Optimizer):
|
||||
def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
super(NovoGrad, self).__init__(params, defaults)
|
||||
self._lr = lr
|
||||
self._beta1 = betas[0]
|
||||
self._beta2 = betas[1]
|
||||
self._eps = eps
|
||||
self._wd = weight_decay
|
||||
self._grad_averaging = grad_averaging
|
||||
|
||||
self._momentum_initialized = False
|
||||
|
||||
def step(self, closure=None):
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
if not self._momentum_initialized:
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
state = self.state[p]
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('NovoGrad does not support sparse gradients')
|
||||
|
||||
v = torch.norm(grad)**2
|
||||
m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data
|
||||
state['step'] = 0
|
||||
state['v'] = v
|
||||
state['m'] = m
|
||||
state['grad_ema'] = None
|
||||
self._momentum_initialized = True
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
state = self.state[p]
|
||||
state['step'] += 1
|
||||
|
||||
step, v, m = state['step'], state['v'], state['m']
|
||||
grad_ema = state['grad_ema']
|
||||
|
||||
grad = p.grad.data
|
||||
g2 = torch.norm(grad)**2
|
||||
grad_ema = g2 if grad_ema is None else grad_ema * \
|
||||
self._beta2 + g2*(1. - self._beta2)
|
||||
grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
|
||||
|
||||
if self._grad_averaging:
|
||||
grad *= (1. - self._beta1)
|
||||
|
||||
g2 = torch.norm(grad)**2
|
||||
v = self._beta2*v + (1. - self._beta2)*g2
|
||||
m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd*p.data)
|
||||
bias_correction1 = 1 - self._beta1 ** step
|
||||
bias_correction2 = 1 - self._beta2 ** step
|
||||
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
state['v'], state['m'] = v, m
|
||||
state['grad_ema'] = grad_ema
|
||||
p.data.add_(-step_size, m)
|
||||
return loss
|
|
@ -1,8 +1,6 @@
|
|||
import torch.nn
|
||||
from models.archs.SPSR_arch import ImageGradientNoPadding
|
||||
from data.weight_scheduler import get_scheduler_for_opt
|
||||
from utils.util import checkpoint
|
||||
import torchvision.utils as utils
|
||||
from utils.weight_scheduler import get_scheduler_for_opt
|
||||
#from models.steps.recursive_gen_injectors import ImageFlowInjector
|
||||
from models.steps.losses import extract_params_from_state
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import torch
|
|||
from apex import amp
|
||||
from collections import OrderedDict
|
||||
from .injectors import create_injector
|
||||
from models.novograd import NovoGrad
|
||||
from utils.util import recursively_detach
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
name: RRDB_ESRGAN_x4
|
||||
suffix: ~ # add suffix to saved images
|
||||
model: sr
|
||||
distortion: sr
|
||||
scale: 4
|
||||
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
||||
gpu_ids: [0]
|
||||
|
||||
datasets:
|
||||
test_1: # the 1st test dataset
|
||||
name: set5
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/val_set5/Set5
|
||||
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
|
||||
test_2: # the 2st test dataset
|
||||
name: set14
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/val_set14/Set14
|
||||
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: RRDBNet
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 64
|
||||
nb: 23
|
||||
upscale: 4
|
||||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth
|
|
@ -1,26 +0,0 @@
|
|||
name: RRDB_ESRGAN_x4
|
||||
suffix: ~ # add suffix to saved images
|
||||
model: sr
|
||||
distortion: sr
|
||||
scale: 4
|
||||
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
||||
gpu_ids: [0]
|
||||
|
||||
datasets:
|
||||
test_1: # the 1st test dataset
|
||||
name: set14
|
||||
mode: LQ
|
||||
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: RRDBNet
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 64
|
||||
nb: 23
|
||||
upscale: 4
|
||||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth
|
|
@ -1,32 +0,0 @@
|
|||
name: MSRGANx4
|
||||
suffix: ~ # add suffix to saved images
|
||||
model: sr
|
||||
distortion: sr
|
||||
scale: 4
|
||||
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
||||
gpu_ids: [0]
|
||||
|
||||
datasets:
|
||||
test_1: # the 1st test dataset
|
||||
name: set5
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/val_set5/Set5
|
||||
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
|
||||
test_2: # the 2st test dataset
|
||||
name: set14
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/val_set14/Set14
|
||||
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: MSRResNet
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 64
|
||||
nb: 16
|
||||
upscale: 4
|
||||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth
|
|
@ -1,48 +0,0 @@
|
|||
name: MSRResNetx4
|
||||
suffix: ~ # add suffix to saved images
|
||||
model: sr
|
||||
distortion: sr
|
||||
scale: 4
|
||||
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
||||
gpu_ids: [0]
|
||||
|
||||
datasets:
|
||||
test_1: # the 1st test dataset
|
||||
name: set5
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/val_set5/Set5
|
||||
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
|
||||
test_2: # the 2st test dataset
|
||||
name: set14
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/val_set14/Set14
|
||||
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||
test_3:
|
||||
name: bsd100
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/BSD/BSDS100
|
||||
dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4
|
||||
test_4:
|
||||
name: urban100
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/urban100
|
||||
dataroot_LQ: ../datasets/urban100_bicLRx4
|
||||
test_5:
|
||||
name: div2k100
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR
|
||||
dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4
|
||||
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: MSRResNet
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 64
|
||||
nb: 16
|
||||
upscale: 4
|
||||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
|
|
@ -1,80 +0,0 @@
|
|||
#### general settings
|
||||
name: 002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new
|
||||
use_tb_logger: true
|
||||
model: video_base
|
||||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [0,1,2,3,4,5,6,7]
|
||||
|
||||
#### datasets
|
||||
datasets:
|
||||
train:
|
||||
name: REDS
|
||||
mode: REDS
|
||||
interval_list: [1]
|
||||
random_reverse: false
|
||||
border_mode: false
|
||||
dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
|
||||
dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
|
||||
cache_keys: ~
|
||||
|
||||
N_frames: 5
|
||||
use_shuffle: true
|
||||
n_workers: 3 # per GPU
|
||||
batch_size: 32
|
||||
target_size: 256
|
||||
LQ_size: 64
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
color: RGB
|
||||
val:
|
||||
name: REDS4
|
||||
mode: video_test
|
||||
dataroot_GT: ../datasets/REDS4/GT
|
||||
dataroot_LQ: ../datasets/REDS4/sharp_bicubic
|
||||
cache_data: True
|
||||
N_frames: 5
|
||||
padding: new_info
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: EDVR
|
||||
nf: 64
|
||||
nframes: 5
|
||||
groups: 8
|
||||
front_RBs: 5
|
||||
back_RBs: 10
|
||||
predeblur: false
|
||||
HR_in: false
|
||||
w_TSA: true
|
||||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ../experiments/pretrained_models/EDVR_REDS_SR_M_woTSA.pth
|
||||
strict_load: false
|
||||
resume_state: ~
|
||||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 4e-4
|
||||
lr_scheme: CosineAnnealingLR_Restart
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
niter: 600000
|
||||
ft_tsa_only: 50000
|
||||
warmup_iter: -1 # -1: no warm up
|
||||
T_period: [50000, 100000, 150000, 150000, 150000]
|
||||
restarts: [50000, 150000, 300000, 450000]
|
||||
restart_weights: [1, 1, 1, 1]
|
||||
eta_min: !!float 1e-7
|
||||
|
||||
pixel_criterion: cb
|
||||
pixel_weight: 1.0
|
||||
val_freq: !!float 5e3
|
||||
|
||||
manual_seed: 0
|
||||
|
||||
#### logger
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: !!float 5e3
|
|
@ -1,71 +0,0 @@
|
|||
#### general settings
|
||||
name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S
|
||||
use_tb_logger: true
|
||||
model: video_base
|
||||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [0,1,2,3,4,5,6,7]
|
||||
|
||||
#### datasets
|
||||
datasets:
|
||||
train:
|
||||
name: REDS
|
||||
mode: REDS
|
||||
interval_list: [1]
|
||||
random_reverse: false
|
||||
border_mode: false
|
||||
dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
|
||||
dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
|
||||
cache_keys: ~
|
||||
|
||||
N_frames: 5
|
||||
use_shuffle: true
|
||||
n_workers: 3 # per GPU
|
||||
batch_size: 32
|
||||
target_size: 256
|
||||
LQ_size: 64
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
color: RGB
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: EDVR
|
||||
nf: 64
|
||||
nframes: 5
|
||||
groups: 8
|
||||
front_RBs: 5
|
||||
back_RBs: 10
|
||||
predeblur: false
|
||||
HR_in: false
|
||||
w_TSA: false
|
||||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ~
|
||||
strict_load: true
|
||||
resume_state: ~
|
||||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 4e-4
|
||||
lr_scheme: CosineAnnealingLR_Restart
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
niter: 600000
|
||||
warmup_iter: -1 # -1: no warm up
|
||||
T_period: [150000, 150000, 150000, 150000]
|
||||
restarts: [150000, 300000, 450000]
|
||||
restart_weights: [1, 1, 1]
|
||||
eta_min: !!float 1e-7
|
||||
|
||||
pixel_criterion: cb
|
||||
pixel_weight: 1.0
|
||||
val_freq: !!float 5e3
|
||||
|
||||
manual_seed: 0
|
||||
|
||||
#### logger
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: !!float 5e3
|
|
@ -1,82 +0,0 @@
|
|||
#### general settings
|
||||
name: 003_RRDB_ESRGANx4_DIV2K
|
||||
use_tb_logger: true
|
||||
model: srgan
|
||||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [0]
|
||||
amp_opt_level: O1
|
||||
|
||||
#### datasets
|
||||
datasets:
|
||||
train:
|
||||
name: DIV2K
|
||||
mode: LQGT
|
||||
dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
|
||||
dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4
|
||||
|
||||
use_shuffle: true
|
||||
n_workers: 16 # per GPU
|
||||
batch_size: 16
|
||||
target_size: 128
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
color: RGB
|
||||
val:
|
||||
name: div2kval
|
||||
mode: LQGT
|
||||
dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
|
||||
dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: ResGen
|
||||
nf: 256
|
||||
network_D:
|
||||
which_model_D: discriminator_resnet_passthrough
|
||||
nf: 42
|
||||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ~
|
||||
strict_load: true
|
||||
resume_state: ~
|
||||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 1e-4
|
||||
weight_decay_G: 0
|
||||
beta1_G: 0.9
|
||||
beta2_G: 0.99
|
||||
lr_D: !!float 1e-4
|
||||
weight_decay_D: 0
|
||||
beta1_D: 0.9
|
||||
beta2_D: 0.99
|
||||
lr_scheme: MultiStepLR
|
||||
|
||||
niter: 400000
|
||||
warmup_iter: -1 # no warm up
|
||||
lr_steps: [50000, 100000, 200000, 300000]
|
||||
lr_gamma: 0.5
|
||||
mega_batch_factor: 1
|
||||
|
||||
pixel_criterion: l1
|
||||
pixel_weight: !!float 1e-2
|
||||
feature_criterion: l1
|
||||
feature_weight: 1
|
||||
feature_weight_decay: .98
|
||||
feature_weight_decay_steps: 500
|
||||
feature_weight_minimum: .1
|
||||
gan_type: gan # gan | ragan
|
||||
gan_weight: !!float 5e-3
|
||||
|
||||
D_update_ratio: 2
|
||||
D_init_iters: 0
|
||||
|
||||
manual_seed: 10
|
||||
val_freq: !!float 5e2
|
||||
|
||||
#### logger
|
||||
logger:
|
||||
print_freq: 50
|
||||
save_checkpoint_freq: !!float 5e2
|
|
@ -1,85 +0,0 @@
|
|||
#### general settings
|
||||
name: esrgan_res
|
||||
use_tb_logger: true
|
||||
model: srgan
|
||||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [0]
|
||||
amp_opt_level: O1
|
||||
|
||||
#### datasets
|
||||
datasets:
|
||||
train:
|
||||
name: DIV2K
|
||||
mode: LQGT
|
||||
dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
|
||||
dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4
|
||||
|
||||
use_shuffle: true
|
||||
n_workers: 0 # per GPU
|
||||
batch_size: 24
|
||||
target_size: 128
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
color: RGB
|
||||
val:
|
||||
name: div2kval
|
||||
mode: LQGT
|
||||
dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
|
||||
dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: ResGen
|
||||
nf: 256
|
||||
nb_denoiser: 2
|
||||
nb_upsampler: 28
|
||||
network_D:
|
||||
which_model_D: discriminator_resnet_passthrough
|
||||
nf: 42
|
||||
|
||||
#### path
|
||||
path:
|
||||
#pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
|
||||
#pretrain_model_D: ~
|
||||
strict_load: true
|
||||
resume_state: ../experiments/esrgan_res/training_state/15500.state
|
||||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 1e-4
|
||||
weight_decay_G: 0
|
||||
beta1_G: 0.9
|
||||
beta2_G: 0.99
|
||||
lr_D: !!float 1e-4
|
||||
weight_decay_D: 0
|
||||
beta1_D: 0.9
|
||||
beta2_D: 0.99
|
||||
lr_scheme: MultiStepLR
|
||||
|
||||
niter: 400000
|
||||
warmup_iter: -1 # no warm up
|
||||
lr_steps: [20000, 40000, 50000, 60000]
|
||||
lr_gamma: 0.5
|
||||
mega_batch_factor: 2
|
||||
|
||||
pixel_criterion: l1
|
||||
pixel_weight: !!float 1e-2
|
||||
feature_criterion: l1
|
||||
feature_weight: 1
|
||||
feature_weight_decay: 1
|
||||
feature_weight_decay_steps: 500
|
||||
feature_weight_minimum: 1
|
||||
gan_type: gan # gan | ragan
|
||||
gan_weight: !!float 1e-2
|
||||
|
||||
D_update_ratio: 2
|
||||
D_init_iters: -1
|
||||
|
||||
manual_seed: 10
|
||||
val_freq: !!float 5e2
|
||||
|
||||
#### logger
|
||||
logger:
|
||||
print_freq: 50
|
||||
save_checkpoint_freq: !!float 5e2
|
|
@ -1,85 +0,0 @@
|
|||
# Not exactly the same as SRGAN in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
|
||||
# With 16 Residual blocks w/o BN
|
||||
|
||||
#### general settings
|
||||
name: 002_SRGANx4_MSRResNetx4Ini_DIV2K
|
||||
use_tb_logger: true
|
||||
model: srgan
|
||||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [1]
|
||||
|
||||
#### datasets
|
||||
datasets:
|
||||
train:
|
||||
name: DIV2K
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
|
||||
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
|
||||
|
||||
use_shuffle: true
|
||||
n_workers: 6 # per GPU
|
||||
batch_size: 16
|
||||
target_size: 128
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
color: RGB
|
||||
val:
|
||||
name: val_set14
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/val_set14/Set14
|
||||
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: MSRResNet
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 64
|
||||
nb: 16
|
||||
upscale: 4
|
||||
network_D:
|
||||
which_model_D: discriminator_vgg_128
|
||||
in_nc: 3
|
||||
nf: 64
|
||||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
|
||||
strict_load: true
|
||||
resume_state: ~
|
||||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 1e-4
|
||||
weight_decay_G: 0
|
||||
beta1_G: 0.9
|
||||
beta2_G: 0.99
|
||||
lr_D: !!float 1e-4
|
||||
weight_decay_D: 0
|
||||
beta1_D: 0.9
|
||||
beta2_D: 0.99
|
||||
lr_scheme: MultiStepLR
|
||||
|
||||
niter: 400000
|
||||
warmup_iter: -1 # no warm up
|
||||
lr_steps: [50000, 100000, 200000, 300000]
|
||||
lr_gamma: 0.5
|
||||
|
||||
pixel_criterion: l1
|
||||
pixel_weight: !!float 1e-2
|
||||
feature_criterion: l1
|
||||
feature_weight: 1
|
||||
gan_type: gan # gan | ragan
|
||||
gan_weight: !!float 5e-3
|
||||
|
||||
D_update_ratio: 1
|
||||
D_init_iters: 0
|
||||
|
||||
manual_seed: 10
|
||||
val_freq: !!float 5e3
|
||||
|
||||
#### logger
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: !!float 5e3
|
|
@ -1,70 +0,0 @@
|
|||
# Not exactly the same as SRResNet in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
|
||||
# With 16 Residual blocks w/o BN
|
||||
|
||||
#### general settings
|
||||
name: 001_MSRResNetx4_scratch_DIV2K
|
||||
use_tb_logger: true
|
||||
model: sr
|
||||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [0]
|
||||
|
||||
#### datasets
|
||||
datasets:
|
||||
train:
|
||||
name: DIV2K
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
|
||||
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
|
||||
|
||||
use_shuffle: true
|
||||
n_workers: 6 # per GPU
|
||||
batch_size: 16
|
||||
target_size: 128
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
color: RGB
|
||||
val:
|
||||
name: val_set5
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/val_set5/Set5
|
||||
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: MSRResNet
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 64
|
||||
nb: 16
|
||||
upscale: 4
|
||||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ~
|
||||
strict_load: true
|
||||
resume_state: ~
|
||||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 2e-4
|
||||
lr_scheme: CosineAnnealingLR_Restart
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
niter: 1000000
|
||||
warmup_iter: -1 # no warm up
|
||||
T_period: [250000, 250000, 250000, 250000]
|
||||
restarts: [250000, 500000, 750000]
|
||||
restart_weights: [1, 1, 1]
|
||||
eta_min: !!float 1e-7
|
||||
|
||||
pixel_criterion: l1
|
||||
pixel_weight: 1.0
|
||||
|
||||
manual_seed: 10
|
||||
val_freq: !!float 5e3
|
||||
|
||||
#### logger
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: !!float 5e3
|
|
@ -11,7 +11,7 @@ import torchvision.transforms.functional as F
|
|||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import options.options as option
|
||||
from utils import options as option
|
||||
import utils.util as util
|
||||
from data import create_dataloader
|
||||
from models import create_model
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
# single GPU training (image SR)
|
||||
python train.py -opt options/train/train_SRResNet.yml
|
||||
python train.py -opt options/train/train_SRGAN.yml
|
||||
python train.py -opt options/train/train_ESRGAN.yml
|
||||
|
||||
|
||||
# distributed training (video SR)
|
||||
# 8 GPUs
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_woTSA_M.yml --launcher pytorch
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_M.yml --launcher pytorch
|
|
@ -1,20 +0,0 @@
|
|||
function [im_h] = backprojection(im_h, im_l, maxIter)
|
||||
|
||||
[row_l, col_l,~] = size(im_l);
|
||||
[row_h, col_h,~] = size(im_h);
|
||||
|
||||
p = fspecial('gaussian', 5, 1);
|
||||
p = p.^2;
|
||||
p = p./sum(p(:));
|
||||
|
||||
im_l = double(im_l);
|
||||
im_h = double(im_h);
|
||||
|
||||
for ii = 1:maxIter
|
||||
im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
|
||||
im_diff = im_l - im_l_s;
|
||||
im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
|
||||
im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
|
||||
im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
|
||||
im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
|
||||
end
|
|
@ -1,22 +0,0 @@
|
|||
clear; close all; clc;
|
||||
|
||||
LR_folder = './LR'; % LR
|
||||
preout_folder = './results'; % pre output
|
||||
save_folder = './results_20bp';
|
||||
filepaths = dir(fullfile(preout_folder, '*.png'));
|
||||
max_iter = 20;
|
||||
|
||||
if ~ exist(save_folder, 'dir')
|
||||
mkdir(save_folder);
|
||||
end
|
||||
|
||||
for idx_im = 1:length(filepaths)
|
||||
fprintf([num2str(idx_im) '\n']);
|
||||
im_name = filepaths(idx_im).name;
|
||||
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
|
||||
im_out = im2double(imread(fullfile(preout_folder, im_name)));
|
||||
%tic
|
||||
im_out = backprojection(im_out, im_LR, max_iter);
|
||||
%toc
|
||||
imwrite(im_out, fullfile(save_folder, im_name));
|
||||
end
|
|
@ -1,25 +0,0 @@
|
|||
clear; close all; clc;
|
||||
|
||||
LR_folder = './LR'; % LR
|
||||
preout_folder = './results'; % pre output
|
||||
save_folder = './results_20if';
|
||||
filepaths = dir(fullfile(preout_folder, '*.png'));
|
||||
max_iter = 20;
|
||||
|
||||
if ~ exist(save_folder, 'dir')
|
||||
mkdir(save_folder);
|
||||
end
|
||||
|
||||
for idx_im = 1:length(filepaths)
|
||||
fprintf([num2str(idx_im) '\n']);
|
||||
im_name = filepaths(idx_im).name;
|
||||
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
|
||||
im_out = im2double(imread(fullfile(preout_folder, im_name)));
|
||||
J = imresize(im_LR,4,'bicubic');
|
||||
%tic
|
||||
for m = 1:max_iter
|
||||
im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
|
||||
end
|
||||
%toc
|
||||
imwrite(im_out, fullfile(save_folder, im_name));
|
||||
end
|
|
@ -1,14 +1,10 @@
|
|||
import torch
|
||||
import os
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import options.options as option
|
||||
from utils import options as option
|
||||
from data import create_dataloader, create_dataset
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from torchvision import transforms
|
||||
from utils.fdpl_util import dct_2d, extract_patches_2d
|
||||
import random
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
from utils.colors import rgb2ycbcr
|
|
@ -1,27 +0,0 @@
|
|||
import os.path as osp
|
||||
import sys
|
||||
import torch
|
||||
try:
|
||||
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||
import models.archs.SRResNet_arch as SRResNet_arch
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth')
|
||||
crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
|
||||
crt_net = crt_model.state_dict()
|
||||
|
||||
for k, v in crt_net.items():
|
||||
if k in pretrained_net and 'upconv1' not in k:
|
||||
crt_net[k] = pretrained_net[k]
|
||||
print('replace ... ', k)
|
||||
|
||||
# x4 -> x3
|
||||
crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
|
||||
crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
|
||||
crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
|
||||
crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
|
||||
crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
|
||||
crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2
|
||||
|
||||
torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')
|
|
@ -2,21 +2,14 @@ import os.path as osp
|
|||
import logging
|
||||
import time
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
|
||||
import os
|
||||
import options.options as option
|
||||
from utils import options as option
|
||||
import utils.util as util
|
||||
from data.util import bgr2ycbcr
|
||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
||||
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
||||
from switched_conv.switched_conv import compute_attention_specificity
|
||||
from data import create_dataset, create_dataloader
|
||||
from models import create_model
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import models.networks as networks
|
||||
import shutil
|
||||
import torchvision
|
||||
|
||||
|
|
@ -5,10 +5,8 @@ import math
|
|||
import argparse
|
||||
import random
|
||||
import torch
|
||||
import options.options as option
|
||||
from utils import util
|
||||
from utils import util, options as option
|
||||
from data import create_dataloader, create_dataset
|
||||
from time import time
|
||||
from tqdm import tqdm
|
||||
from skimage import io
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
rm gen/*
|
||||
rm hr/*
|
||||
rm lr/*
|
||||
rm pix/*
|
||||
rm ref/*
|
||||
rm genlr/*
|
||||
rm genmr/*
|
||||
rm lr_precorrupt/*
|
||||
rm ref/*
|
|
@ -1,21 +1,5 @@
|
|||
import os.path as osp
|
||||
import logging
|
||||
import time
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
|
||||
import os
|
||||
import options.options as option
|
||||
import utils.util as util
|
||||
from data.util import bgr2ycbcr
|
||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
||||
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
||||
from switched_conv.switched_conv import compute_attention_specificity
|
||||
from data import create_dataset, create_dataloader
|
||||
from models import create_model
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import models.networks as networks
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
@ -39,7 +23,7 @@ class CheckpointFunction(torch.autograd.Function):
|
|||
input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)
|
||||
return (None, None) + input_grads
|
||||
|
||||
from models.archs.arch_util import ConvGnSilu, UpconvBlock
|
||||
from models.archs.arch_util import ConvGnSilu
|
||||
import torch.nn as nn
|
||||
if __name__ == "__main__":
|
||||
model = nn.Sequential(ConvGnSilu(3, 64, 3, norm=False),
|
||||
|
|
|
@ -3,16 +3,14 @@ import math
|
|||
import argparse
|
||||
import random
|
||||
import logging
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from data.data_sampler import DistIterSampler
|
||||
|
||||
import options.options as option
|
||||
from utils import util
|
||||
from utils import util, options as option
|
||||
from data import create_dataloader, create_dataset
|
||||
from models import create_model
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from time import time
|
||||
|
||||
|
||||
|
@ -159,7 +157,7 @@ def main():
|
|||
assert train_loader is not None
|
||||
|
||||
#### create model
|
||||
model = create_model(opt)
|
||||
model = ExtensibleTrainer(opt)
|
||||
|
||||
#### resume training
|
||||
if resume_state:
|
||||
|
|
|
@ -3,16 +3,14 @@ import math
|
|||
import argparse
|
||||
import random
|
||||
import logging
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from data.data_sampler import DistIterSampler
|
||||
|
||||
import options.options as option
|
||||
from utils import util
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from utils import util, options as option
|
||||
from data import create_dataloader, create_dataset
|
||||
from models import create_model
|
||||
from time import time
|
||||
|
||||
|
||||
|
@ -159,7 +157,7 @@ def main():
|
|||
assert train_loader is not None
|
||||
|
||||
#### create model
|
||||
model = create_model(opt)
|
||||
model = ExtensibleTrainer(opt)
|
||||
|
||||
#### resume training
|
||||
if resume_state:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive"
|
||||
# models which can be trained piecemeal.
|
||||
|
||||
import options.options as option
|
||||
from utils import options as option
|
||||
from models import create_model
|
||||
import torch
|
||||
import os
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import argparse
|
||||
import functools
|
||||
import torch
|
||||
import options.options as option
|
||||
from utils import options as option
|
||||
from models.networks import define_G
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user