diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 9f0e0622..00000000 --- a/.style.yapf +++ /dev/null @@ -1,4 +0,0 @@ -[style] -BASED_ON_STYLE = pep8 -COLUMN_LIMIT = 100 -SPLIT_BEFORE_NAMED_ASSIGNS = false \ No newline at end of file diff --git a/codes/data/Downsample_dataset.py b/codes/data/Downsample_dataset.py deleted file mode 100644 index e6861ef5..00000000 --- a/codes/data/Downsample_dataset.py +++ /dev/null @@ -1,124 +0,0 @@ -import random -import numpy as np -import cv2 -import lmdb -import torch -import torch.utils.data as data -import data.util as util -from PIL import Image -from io import BytesIO -import torchvision.transforms.functional as F - - -class DownsampleDataset(data.Dataset): - """ - Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a - downsampled LQ image from the HQ image and feeds that as well. - """ - - def __init__(self, opt): - super(DownsampleDataset, self).__init__() - self.opt = opt - self.data_type = self.opt['data_type'] - self.paths_LQ, self.paths_GT = None, None - self.sizes_LQ, self.sizes_GT = None, None - self.LQ_env, self.GT_env = None, None # environments for lmdb - self.doCrop = self.opt['doCrop'] - - self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) - self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) - - self.data_sz_mismatch_ok = opt['mismatched_Data_OK'] - assert self.paths_GT, 'Error: GT path is empty.' - assert self.paths_LQ, 'LQ is required for downsampling.' - if not self.data_sz_mismatch_ok: - assert len(self.paths_LQ) == len( - self.paths_GT - ), 'GT and LQ datasets have different number of images - {}, {}.'.format( - len(self.paths_LQ), len(self.paths_GT)) - self.random_scale_list = [1] - - def _init_lmdb(self): - # https://github.com/chainer/chainermn/issues/129 - self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, - meminit=False) - self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, - meminit=False) - - def __getitem__(self, index): - if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): - self._init_lmdb() - scale = self.opt['scale'] - GT_size = self.opt['target_size'] * scale - - # get GT image - GT_path = self.paths_GT[index % len(self.paths_GT)] - resolution = [int(s) for s in self.sizes_GT[index].split('_') - ] if self.data_type == 'lmdb' else None - img_GT = util.read_img(self.GT_env, GT_path, resolution) - if self.opt['phase'] != 'train': # modcrop in the validation / test phase - img_GT = util.modcrop(img_GT, scale) - if self.opt['color']: # change color space if necessary - img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] - - # get LQ image - LQ_path = self.paths_LQ[index % len(self.paths_LQ)] - resolution = [int(s) for s in self.sizes_LQ[index].split('_') - ] if self.data_type == 'lmdb' else None - img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) - - if self.opt['phase'] == 'train': - H, W, _ = img_GT.shape - assert H >= GT_size and W >= GT_size - - H, W, C = img_LQ.shape - LQ_size = GT_size // scale - - if self.doCrop: - # randomly crop - rnd_h = random.randint(0, max(0, H - LQ_size)) - rnd_w = random.randint(0, max(0, W - LQ_size)) - img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] - rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) - img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] - 
else:
-                img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
-                img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
-
-            # augmentation - flip, rotate
-            img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
-                                          self.opt['use_rot'])
-
-        # BGR to RGB, HWC to CHW, numpy to tensor
-        if img_GT.shape[2] == 3:
-            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
-            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
-
-        # HQ needs to go to a PIL image to perform the compression-artifact transformation.
-        H, W, _ = img_GT.shape
-        img_GT = (img_GT * 255).astype(np.uint8)
-        img_GT = Image.fromarray(img_GT)
-        if self.opt['use_compression_artifacts']:
-            qf = random.randrange(15, 100)
-            corruption_buffer = BytesIO()
-            img_GT.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
-            corruption_buffer.seek(0)
-            img_GT = Image.open(corruption_buffer)
-        # Generate a downsampled image from HQ for feature and PIX losses.
-        img_Downsampled = F.resize(img_GT, (int(H / scale), int(W / scale)))
-
-        img_GT = F.to_tensor(img_GT)
-        img_Downsampled = F.to_tensor(img_Downsampled)
-        img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
-
-        # This may seem backwards, so an explanation is in order:
-        # The goal is to reuse existing code as much as possible. SRGAN_model was written to super-resolve,
-        # not downsample, but it can be retrofitted. To do so, we need to "trick" it: here the "input" is
-        # the HQ image and the "output" is the LQ image. SRGAN_model uses a Generator and a Discriminator
-        # which already know this; we just need to trick its logic into following these rules.
-        # Do this by setting LQ (the input into the models) = img_GT and GT (the expected output) = img_LQ.
-        # PIX is used as the reference for the pixel loss; use the manually downsampled image for this.
-        return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}
-
-    def __len__(self):
-        return max(len(self.paths_GT), len(self.paths_LQ))
diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py
deleted file mode 100644
index d1fa1f50..00000000
--- a/codes/data/LQGT_dataset.py
+++ /dev/null
@@ -1,239 +0,0 @@
-import random
-import numpy as np
-import cv2
-import lmdb
-import torch
-import torch.utils.data as data
-import data.util as util
-from PIL import Image, ImageOps
-from io import BytesIO
-import torchvision.transforms.functional as F
-
-
-class LQGTDataset(data.Dataset):
-    """
-    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc.) and GT image pairs.
-    If only GT images are provided, generate LQ images on-the-fly.
- """ - def get_lq_path(self, i): - which_lq = random.randint(0, len(self.paths_LQ)-1) - return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])] - - def __init__(self, opt): - super(LQGTDataset, self).__init__() - self.opt = opt - self.data_type = self.opt['data_type'] - self.paths_LQ, self.paths_GT = None, None - self.sizes_LQ, self.sizes_GT = None, None - self.paths_PIX, self.sizes_PIX = None, None - self.paths_GAN, self.sizes_GAN = None, None - self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs - self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1 - - self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights']) - if 'dataroot_LQ' in opt.keys(): - self.paths_LQ = [] - if isinstance(opt['dataroot_LQ'], list): - # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and - # we want the model to learn them all. - for dr_lq in opt['dataroot_LQ']: - lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq) - self.paths_LQ.append(lq_path) - else: - lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) - self.paths_LQ.append(lq_path) - self.doCrop = opt['doCrop'] - if 'dataroot_PIX' in opt.keys(): - self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX']) - # dataroot_GAN is an alternative source of LR images specifically for use in computing the GAN loss, where - # LR and HR do not need to be paired. - if 'dataroot_GAN' in opt.keys(): - self.paths_GAN, self.sizes_GAN = util.get_image_paths(self.data_type, opt['dataroot_GAN']) - print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,)) - - assert self.paths_GT, 'Error: GT path is empty.' 
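
Aside on the multi-source LQ handling above: when `dataroot_LQ` is a list, `get_lq_path` picks one corruption source uniformly at random per fetch and wraps the index so sources of different lengths still line up. A minimal standalone sketch of that sampling pattern (class and path names here are illustrative, not from the original code):

```python
import random

class MultiSourceSampler:
    """Picks one LQ rendition per request from several corruption folders."""

    def __init__(self, sources):
        # sources: list of lists, one list of file paths per corruption type.
        self.sources = sources

    def pick(self, index):
        which = random.randint(0, len(self.sources) - 1)
        chosen = self.sources[which]
        # Wrap the index so differently-sized sources stay usable.
        return chosen[index % len(chosen)]

paths = MultiSourceSampler([["a/0.png", "a/1.png"], ["b/0.png"]])
print(paths.pick(5))  # e.g. 'a/1.png' or 'b/0.png'
```
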
- self.random_scale_list = [1] - - def _init_lmdb(self): - # https://github.com/chainer/chainermn/issues/129 - self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, - meminit=False) - self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, - meminit=False) - if 'dataroot_PIX' in self.opt.keys(): - self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False, - meminit=False) - - def motion_blur(self, image, size, angle): - k = np.zeros((size, size), dtype=np.float32) - k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32) - k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size)) - k = k * (1.0 / np.sum(k)) - return cv2.filter2D(image, -1, k) - - def __getitem__(self, index): - if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): - self._init_lmdb() - GT_path, LQ_path = None, None - scale = self.opt['scale'] - GT_size = self.opt['target_size'] - - # get GT image - GT_path = self.paths_GT[index % len(self.paths_GT)] - resolution = [int(s) for s in self.sizes_GT[index].split('_') - ] if self.data_type == 'lmdb' else None - img_GT = util.read_img(self.GT_env, GT_path, resolution) - if self.opt['phase'] != 'train': # modcrop in the validation / test phase - img_GT = util.modcrop(img_GT, scale) - if self.opt['color']: # change color space if necessary - img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] - - # get the pix image - if self.paths_PIX is not None: - PIX_path = self.paths_PIX[index % len(self.paths_PIX)] - img_PIX = util.read_img(self.PIX_env, PIX_path, resolution) - if self.opt['color']: # change color space if necessary - img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0] - else: - img_PIX = img_GT - - # get LQ image - if self.paths_LQ: - LQ_path = self.get_lq_path(index) - resolution = [int(s) for s in self.sizes_LQ[index].split('_') - ] if self.data_type == 'lmdb' else None - img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) - else: # down-sampling on-the-fly - # randomly scale during training - if self.opt['phase'] == 'train': - random_scale = random.choice(self.random_scale_list) - H_s, W_s, _ = img_GT.shape - - def _mod(n, random_scale, scale, thres): - rlt = int(n * random_scale) - rlt = (rlt // scale) * scale - return thres if rlt < thres else rlt - - H_s = _mod(H_s, random_scale, scale, GT_size) - W_s = _mod(W_s, random_scale, scale, GT_size) - img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR) - if img_GT.ndim == 2: - img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) - - H, W, _ = img_GT.shape - - # using matlab imresize - if scale == 1: - img_LQ = img_GT - else: - img_LQ = util.imresize_np(img_GT, 1 / scale, True) - if img_LQ.ndim == 2: - img_LQ = np.expand_dims(img_LQ, axis=2) - - img_GAN = None - if self.paths_GAN: - GAN_path = self.paths_GAN[index % self.sizes_GAN] - img_GAN = util.read_img(self.LQ_env, GAN_path) - - # Enforce force_resize constraints. 
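
The `# Enforce force_resize constraints.` block that follows trims the LQ image so its height and width divide evenly by `force_multiple` (a stride/padding requirement of many fully-convolutional generators), then applies the scaled crop to the GT and PIX images. A standalone sketch of the rounding, assuming HWC NumPy arrays:

```python
import numpy as np

def crop_to_multiple(img: np.ndarray, multiple: int) -> np.ndarray:
    """Trim an HWC image so both spatial dims divide evenly by `multiple`."""
    h, w = img.shape[:2]
    h -= h % multiple
    w -= w % multiple
    return img[:h, :w, ...]

lq = np.zeros((37, 61, 3))
print(crop_to_multiple(lq, 8).shape)  # (32, 56, 3)
```
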
- h, w, _ = img_LQ.shape - if h % self.force_multiple != 0 or w % self.force_multiple != 0: - h, w = (h - h % self.force_multiple), (w - w % self.force_multiple) - img_LQ = img_LQ[:h, :w, :] - h *= scale - w *= scale - img_GT = img_GT[:h, :w, :] - img_PIX = img_PIX[:h, :w, :] - - if self.opt['phase'] == 'train': - H, W, _ = img_GT.shape - assert H >= GT_size and W >= GT_size - - H, W, C = img_LQ.shape - LQ_size = GT_size // scale - - if self.doCrop: - # randomly crop - rnd_h = random.randint(0, max(0, H - LQ_size)) - rnd_w = random.randint(0, max(0, W - LQ_size)) - img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] - if img_GAN is not None: - img_GAN = img_GAN[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] - rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) - img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] - img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] - else: - if img_LQ.shape[0] != LQ_size: - img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) - if img_GAN is not None: - img_GAN = cv2.resize(img_GAN, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) - if img_GT.shape[0] != GT_size: - img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) - if img_PIX.shape[0] != GT_size: - img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) - - if 'doResizeLoss' in self.opt.keys() and self.opt['doResizeLoss']: - r = random.randrange(0, 10) - if r > 5: - img_LQ = cv2.resize(img_LQ, (int(LQ_size/2), int(LQ_size/2)), interpolation=cv2.INTER_LINEAR) - img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) - - # augmentation - flip, rotate - img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'], - self.opt['use_rot']) - - if self.opt['use_blurring']: - # Pick randomly between gaussian, motion, or no blur. - blur_det = random.randint(0, 100) - blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude'] - if blur_det < 40: - blur_sig = int(random.randrange(0, blur_magnitude)) - img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig) - elif blur_det < 70: - img_LQ = self.motion_blur(img_LQ, random.randrange(1, blur_magnitude * 3), random.randint(0, 360)) - - - if self.opt['color']: # change color space if necessary - img_LQ = util.channel_convert(C, self.opt['color'], - [img_LQ])[0] # TODO during val no definition - - # BGR to RGB, HWC to CHW, numpy to tensor - if img_GT.shape[2] == 3: - img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB) - img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB) - if img_GAN is not None: - img_GAN = cv2.cvtColor(img_GAN, cv2.COLOR_BGR2RGB) - img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB) - - # LQ needs to go to a PIL image to perform the compression-artifact transformation. 
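
The compression-artifact transformation named in the comment above is an in-memory JPEG round-trip at a random quality factor; the deleted lines below implement it inline. A self-contained sketch of the same trick:

```python
import random
from io import BytesIO

from PIL import Image

def jpeg_corrupt(img: Image.Image, q_min: int = 10, q_max: int = 70) -> Image.Image:
    """Re-encode a PIL image as JPEG at a random quality to inject block artifacts."""
    quality = random.randrange(q_min, q_max)
    buffer = BytesIO()
    img.convert("RGB").save(buffer, "JPEG", quality=quality, optimize=True)
    buffer.seek(0)
    return Image.open(buffer)

corrupted = jpeg_corrupt(Image.new("RGB", (64, 64), "gray"))
print(corrupted.size)
```
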
-        img_LQ = (img_LQ * 255).astype(np.uint8)
-        img_LQ = Image.fromarray(img_LQ)
-        if self.opt['use_compression_artifacts'] and random.random() > .25:
-            qf = random.randrange(10, 70)
-            corruption_buffer = BytesIO()
-            img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
-            corruption_buffer.seek(0)
-            img_LQ = Image.open(corruption_buffer)
-
-        if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
-            img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
-
-        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
-        img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
-        img_LQ = F.to_tensor(img_LQ)
-        if img_GAN is not None:
-            img_GAN = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GAN, (2, 0, 1)))).float()
-
-        if 'lq_noise' in self.opt.keys():
-            lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
-            img_LQ += lq_noise
-
-        if LQ_path is None:
-            LQ_path = GT_path
-        d = {'LQ': img_LQ, 'GT': img_GT, 'ref': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
-        if img_GAN is not None:
-            d['GAN'] = img_GAN
-        return d
-
-    def __len__(self):
-        return len(self.paths_GT)
diff --git a/codes/data/LQ_dataset.py b/codes/data/LQ_dataset.py
deleted file mode 100644
index e730408b..00000000
--- a/codes/data/LQ_dataset.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import numpy as np
-import lmdb
-import torch
-import torch.utils.data as data
-import data.util as util
-import torchvision.transforms.functional as F
-from PIL import Image
-import os.path as osp
-import cv2
-
-
-class LQDataset(data.Dataset):
-    '''Read LQ images only in the test phase.'''
-
-    def __init__(self, opt):
-        super(LQDataset, self).__init__()
-        self.opt = opt
-        self.data_type = self.opt['data_type']
-        if 'start_at' in self.opt.keys():
-            self.start_at = self.opt['start_at']
-        else:
-            self.start_at = 0
-        self.vertical_splits = self.opt['vertical_splits']
-        self.paths_LQ, self.paths_GT = None, None
-        self.LQ_env = None  # environment for lmdb
-        self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
-
-        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
-        self.paths_LQ = self.paths_LQ[self.start_at:]
-        assert self.paths_LQ, 'Error: LQ paths are empty.'
-
-    def _init_lmdb(self):
-        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
-                                meminit=False)
-
-    def __getitem__(self, index):
-        if self.data_type == 'lmdb' and self.LQ_env is None:
-            self._init_lmdb()
-        if self.vertical_splits > 0:
-            actual_index = int(index / self.vertical_splits)
-        else:
-            actual_index = index
-
-        # get LQ image
-        LQ_path = self.paths_LQ[actual_index]
-        img_LQ = Image.open(LQ_path)
-        if self.vertical_splits > 0:
-            w, h = img_LQ.size
-            split_index = (index % self.vertical_splits)
-            w_per_split = int(w / self.vertical_splits)
-            left = w_per_split * split_index
-            img_LQ = F.crop(img_LQ, 0, left, h, w_per_split)
-
-        # Enforce force_resize constraints.
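
A gotcha in the force-multiple lines that follow: `PIL.Image.size` returns `(width, height)`, so `h, w = img_LQ.size` binds the names backwards; the second tuple assignment swaps them again, and `resize((w, h))` ends up receiving the `(width, height)` order PIL expects. The logic is correct but easy to misread. The same operation with unambiguous names:

```python
from PIL import Image

def resize_to_multiple(img: Image.Image, multiple: int) -> Image.Image:
    """Shrink a PIL image so width and height divide evenly by `multiple`."""
    width, height = img.size  # PIL reports (width, height), not (h, w)
    width -= width % multiple
    height -= height % multiple
    return img.resize((width, height))

print(resize_to_multiple(Image.new("RGB", (61, 37)), 8).size)  # (56, 32)
```
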
- h, w = img_LQ.size - if h % self.force_multiple != 0 or w % self.force_multiple != 0: - h, w = (w - w % self.force_multiple), (h - h % self.force_multiple) - img_LQ = img_LQ.resize((w, h)) - - img_LQ = F.to_tensor(img_LQ) - - img_name = osp.splitext(osp.basename(LQ_path))[0] - LQ_path = LQ_path.replace(img_name, img_name + "_%i" % (index % self.vertical_splits)) - - return {'LQ': img_LQ, 'LQ_path': LQ_path} - - def __len__(self): - if self.vertical_splits > 0: - return len(self.paths_LQ) * self.vertical_splits - else: - return len(self.paths_LQ) diff --git a/codes/data/__init__.py b/codes/data/__init__.py index cf6434bd..b08d7d9f 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -29,14 +29,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): def create_dataset(dataset_opt): mode = dataset_opt['mode'] # datasets for image restoration - if mode == 'LQ': - from data.LQ_dataset import LQDataset as D - elif mode == 'LQGT': - from data.LQGT_dataset import LQGTDataset as D - # datasets for image corruption - elif mode == 'downsample': - from data.Downsample_dataset import DownsampleDataset as D - elif mode == 'fullimage': + if mode == 'fullimage': from data.full_image_dataset import FullImageDataset as D elif mode == 'single_image_extensible': from data.single_image_dataset import SingleImageDataset as D diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py deleted file mode 100644 index 14b2a461..00000000 --- a/codes/models/SRGAN_model.py +++ /dev/null @@ -1,1045 +0,0 @@ -import logging -from collections import OrderedDict -import torch -import torch.nn as nn -from torch.nn.parallel import DataParallel, DistributedDataParallel -import models.networks as networks -import models.lr_scheduler as lr_scheduler -from models.base_model import BaseModel -from models.loss import GANLoss, FDPLLoss -from apex import amp -from data.weight_scheduler import get_scheduler_for_opt -from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding -import torch.nn.functional as F -import glob -import random - -import torchvision.utils as utils -import os - -logger = logging.getLogger('base') - - -class GaussianBlur(nn.Module): - def __init__(self): - super(GaussianBlur, self).__init__() - - # Set these to whatever you want for your gaussian filter - kernel_size = 3 - sigma = 2 - - # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) - x_cord = torch.arange(kernel_size) - x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size) - y_grid = x_grid.t() - xy_grid = torch.stack([x_grid, y_grid], dim=-1) - - mean = (kernel_size - 1) / 2. - variance = sigma ** 2. - - # Calculate the 2-dimensional gaussian kernel which is - # the product of two gaussian distributions for two different - # variables (in this case called x and y) - gaussian_kernel = (1. / (2. * 3.1415926 * variance)) * \ - torch.exp( - -torch.sum((xy_grid - mean) ** 2., dim=-1) / \ - (2 * variance) - ) - # Make sure sum of values in gaussian kernel equals 1. 
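
The normalization below makes the kernel sum to one, after which the module is just a frozen depthwise (`groups=3`) convolution. A functional sketch of the equivalent blur, assuming NCHW input:

```python
import torch
import torch.nn.functional as F

def gaussian_blur(x: torch.Tensor, kernel_size: int = 3, sigma: float = 2.0) -> torch.Tensor:
    """Blur an NCHW batch with a fixed Gaussian applied depthwise (per channel)."""
    coords = torch.arange(kernel_size, dtype=torch.float32)
    mean = (kernel_size - 1) / 2.0
    gauss_1d = torch.exp(-((coords - mean) ** 2) / (2 * sigma ** 2))
    kernel = torch.outer(gauss_1d, gauss_1d)
    kernel = kernel / kernel.sum()  # sums to 1 so overall brightness is preserved
    channels = x.shape[1]
    weight = kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)
    return F.conv2d(x, weight, groups=channels)  # no padding, matching the original module

print(gaussian_blur(torch.rand(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 6, 6])
```
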
-        gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
-
-        # Reshape to 2d depthwise convolutional weight
-        gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
-        gaussian_kernel = gaussian_kernel.repeat(3, 1, 1, 1)
-
-        self.gaussian_filter = nn.Conv2d(in_channels=3, out_channels=3,
-                                         kernel_size=kernel_size, groups=3, bias=False)
-
-        self.gaussian_filter.weight.data = gaussian_kernel
-        self.gaussian_filter.weight.requires_grad = False
-
-    def forward(self, x):
-        return self.gaussian_filter(x)
-
-
-class SRGANModel(BaseModel):
-    def __init__(self, opt):
-        super(SRGANModel, self).__init__(opt)
-        if opt['dist']:
-            self.rank = torch.distributed.get_rank()
-        else:
-            self.rank = -1  # non dist training
-        train_opt = opt['train']
-        self.spsr_enabled = 'spsr' in opt['model']
-
-        # define networks and load pretrained models
-        self.netG = networks.define_G(opt).to(self.device)
-        if self.is_train:
-            self.netD = networks.define_D(opt).to(self.device)
-            if self.spsr_enabled:
-                logger.info("Defining grad net...")
-                self.netD_grad = networks.define_D(opt, wrap=True).to(self.device)  # D_grad
-
-        if 'network_C' in opt.keys():
-            self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
-            # The corruptor net is fixed. Lock it down.
-            self.netC.eval()
-            for p in self.netC.parameters():
-                p.requires_grad = False
-        else:
-            self.netC = None
-        self.mega_batch_factor = 1
-        self.disjoint_data = False
-
-        # define losses, optimizer and scheduler
-        if self.is_train:
-            self.mega_batch_factor = train_opt['mega_batch_factor']
-            if self.mega_batch_factor is None:
-                self.mega_batch_factor = 1
-            # G pixel loss
-            if train_opt['pixel_weight'] > 0:
-                l_pix_type = train_opt['pixel_criterion']
-                if l_pix_type == 'l1':
-                    self.cri_pix = nn.L1Loss().to(self.device)
-                elif l_pix_type == 'l2':
-                    self.cri_pix = nn.MSELoss().to(self.device)
-                else:
-                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
-                self.l_pix_w = train_opt['pixel_weight']
-            else:
-                logger.info('Remove pixel loss.')
-                self.cri_pix = None
-
-            # FDPL loss.
- if 'fdpl_loss' in train_opt.keys(): - fdpl_opt = train_opt['fdpl_loss'] - self.fdpl_weight = fdpl_opt['weight'] - self.fdpl_enabled = self.fdpl_weight > 0 - if self.fdpl_enabled: - self.cri_fdpl = FDPLLoss(fdpl_opt['data_mean'], self.device) - else: - self.fdpl_enabled = False - - if self.spsr_enabled: - spsr_opt = train_opt['spsr'] - self.branch_pretrain = spsr_opt['branch_pretrain'] if spsr_opt['branch_pretrain'] else 0 - self.branch_init_iters = spsr_opt['branch_init_iters'] if spsr_opt['branch_init_iters'] else 1 - if spsr_opt['gradient_pixel_weight'] > 0: - self.cri_pix_grad = nn.MSELoss().to(self.device) - self.l_pix_grad_w = spsr_opt['gradient_pixel_weight'] - else: - self.cri_pix_grad = None - if spsr_opt['gradient_gan_weight'] > 0: - self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) - self.l_gan_grad_w = spsr_opt['gradient_gan_weight'] - else: - self.cri_grad_gan = None - if spsr_opt['pixel_branch_weight'] > 0: - l_pix_type = spsr_opt['pixel_branch_criterion'] - if l_pix_type == 'l1': - self.cri_pix_branch = nn.L1Loss().to(self.device) - elif l_pix_type == 'l2': - self.cri_pix_branch = nn.MSELoss().to(self.device) - else: - raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) - self.l_pix_branch_w = spsr_opt['pixel_branch_weight'] - else: - logger.info('Remove G_grad pixel loss.') - self.cri_pix_branch = None - - # G feature loss - if train_opt['feature_weight'] and train_opt['feature_weight'] > 0: - # For backwards compatibility, use a scheduler definition instead. Remove this at some point. - l_fea_type = train_opt['feature_criterion'] - if l_fea_type == 'l1': - self.cri_fea = nn.L1Loss().to(self.device) - elif l_fea_type == 'l2': - self.cri_fea = nn.MSELoss().to(self.device) - else: - raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) - sched_params = { - 'type': 'fixed', - 'weight': train_opt['feature_weight'] - } - self.l_fea_sched = get_scheduler_for_opt(sched_params) - elif train_opt['feature_scheduler']: - self.l_fea_sched = get_scheduler_for_opt(train_opt['feature_scheduler']) - l_fea_type = train_opt['feature_criterion'] - if l_fea_type == 'l1': - self.cri_fea = nn.L1Loss().to(self.device) - elif l_fea_type == 'l2': - self.cri_fea = nn.MSELoss().to(self.device) - else: - raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) - else: - logger.info('Remove feature loss.') - self.cri_fea = None - if self.cri_fea: # load VGG perceptual loss - self.use_corrupted_feature_input = train_opt['corrupted_feature_input'] if 'corrupted_feature_input' in train_opt.keys() else False - if self.use_corrupted_feature_input: - logger.info("Corrupting inputs into the feature network..") - self.feature_corruptor = GaussianBlur().to(self.device) - self.netF = networks.define_F(use_bn=False).to(self.device) - self.lr_netF = None - if 'lr_fea_path' in train_opt.keys(): - self.lr_netF = networks.define_F(use_bn=False, load_path=train_opt['lr_fea_path']).to(self.device) - self.disjoint_data = True - - if opt['dist']: - pass # do not need to use DistributedDataParallel for netF - else: - self.netF = DataParallel(self.netF) - if self.lr_netF: - self.lr_netF = DataParallel(self.lr_netF) - - # You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses. 
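
As the trailing comment says, the frozen pre-trained discriminators loaded next are consumed like perceptual losses: real and fake images pass through the frozen net and the feature distance is penalized, scaled by the network's `fdisc_weight`. A minimal sketch of that pattern (the `FixedDisc` stand-in is illustrative, not the repo's `define_fixed_D`):

```python
import torch
import torch.nn as nn

class FixedDisc(nn.Module):
    """Stand-in for a frozen, pre-trained discriminator with a loss weight."""

    def __init__(self, weight: float):
        super().__init__()
        self.fdisc_weight = weight
        self.features = nn.Conv2d(3, 8, 3)
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.features(x)

def fixed_disc_feature_loss(fixed_nets, fake, real, criterion=nn.L1Loss()):
    """Sum weighted feature-matching losses from frozen discriminators."""
    total = torch.zeros((), device=fake.device)
    for net in fixed_nets:
        real_fea = net(real).detach()  # target features carry no gradient
        fake_fea = net(fake)           # gradient flows back into the generator
        total = total + net.fdisc_weight * criterion(fake_fea, real_fea)
    return total

loss = fixed_disc_feature_loss([FixedDisc(0.1)], torch.rand(1, 3, 16, 16), torch.rand(1, 3, 16, 16))
print(loss)
```
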
- self.fixed_disc_nets = [] - if 'fixed_discriminators' in opt.keys(): - for opt_fdisc in opt['fixed_discriminators'].keys(): - netFD = networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device) - if opt['dist']: - pass # do not need to use DistributedDataParallel for netF - else: - netFD = DataParallel(netFD) - self.fixed_disc_nets.append(netFD) - - # GD gan loss - self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) - self.l_gan_w = train_opt['gan_weight'] - # D_update_ratio and D_init_iters - self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 - self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 - self.G_warmup = train_opt['G_warmup'] if train_opt['G_warmup'] else -1 - self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0 - self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0 - self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0 - self.corruptor_swapout_steps = train_opt['corruptor_swapout_steps'] if train_opt['corruptor_swapout_steps'] else 500 - self.corruptor_usage_prob = train_opt['corruptor_usage_probability'] if train_opt['corruptor_usage_probability'] else .5 - - # optimizers - # G optimizer - wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 - optim_params = [] - if train_opt['lr_scheme'] == 'ProgressiveMultiStepLR': - optim_params = self.netG.get_param_groups() - else: - for k, v in self.netG.named_parameters(): # can optimize for a part of the model - if v.requires_grad: - optim_params.append(v) - else: - if self.rank <= 0: - logger.warning('Params [{:s}] will not optimize.'.format(k)) - self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], - weight_decay=wd_G, - betas=(train_opt['beta1_G'], train_opt['beta2_G'])) - self.optimizers.append(self.optimizer_G) - # D optimizer - optim_params = [] - for k, v in self.netD.named_parameters(): # can optimize for a part of the model - if v.requires_grad: - optim_params.append(v) - else: - if self.rank <= 0: - logger.warning('Params [{:s}] will not optimize.'.format(k)) - wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 - self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'], - weight_decay=wd_D, - betas=(train_opt['beta1_D'], train_opt['beta2_D'])) - self.optimizers.append(self.optimizer_D) - self.disc_optimizers.append(self.optimizer_D) - - if self.spsr_enabled: - # D_grad optimizer - optim_params = [] - for k, v in self.netD_grad.named_parameters(): # can optimize for a part of the model - if v.requires_grad: - optim_params.append(v) - else: - if self.rank <= 0: - logger.warning('Params [{:s}] will not optimize.'.format(k)) - # D - wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 - self.optimizer_D_grad = torch.optim.Adam(optim_params, lr=train_opt['lr_D'], - weight_decay=wd_D, - betas=(train_opt['beta1_D'], train_opt['beta2_D'])) - self.optimizers.append(self.optimizer_D_grad) - self.disc_optimizers.append(self.optimizer_D_grad) - - if self.spsr_enabled: - self.get_grad_nopadding = ImageGradientNoPadding().to(self.device) - [self.netG, self.netD, self.netD_grad, self.get_grad_nopadding], \ - [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \ - amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad_nopadding], - [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad], 
- opt_level=self.amp_level, num_losses=3) - else: - # AMP - [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \ - amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3) - - # DataParallel - if opt['dist']: - self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()], - find_unused_parameters=True) - else: - self.netG = DataParallel(self.netG) - if self.is_train: - if opt['dist']: - self.netD = DistributedDataParallel(self.netD, - device_ids=[torch.cuda.current_device()], - find_unused_parameters=True) - if self.spsr_enabled: - self.netD_grad = DistributedDataParallel(self.netD_grad, - device_ids=[torch.cuda.current_device()], - find_unused_parameters=True) - else: - self.netD = DataParallel(self.netD) - if self.spsr_enabled: - self.netD_grad = DataParallel(self.netD_grad) - self.netG.train() - self.netD.train() - if self.spsr_enabled: - self.netD_grad.train() - - # schedulers - if train_opt['lr_scheme'] == 'MultiStepLR': - # This is a recent change. assert to make sure any legacy configs dont find their way here. - assert 'gen_lr_steps' in train_opt.keys() and 'disc_lr_steps' in train_opt.keys() - self.schedulers.append( - lr_scheduler.MultiStepLR_Restart(self.optimizer_G, train_opt['gen_lr_steps'], - restarts=train_opt['restarts'], - weights=train_opt['restart_weights'], - gamma=train_opt['lr_gamma'], - clear_state=train_opt['clear_state'], - force_lr=train_opt['force_lr'])) - for o in self.disc_optimizers: - self.schedulers.append( - lr_scheduler.MultiStepLR_Restart(o, train_opt['disc_lr_steps'], - restarts=train_opt['restarts'], - weights=train_opt['restart_weights'], - gamma=train_opt['lr_gamma'], - clear_state=train_opt['clear_state'], - force_lr=train_opt['force_lr'])) - elif train_opt['lr_scheme'] == 'ProgressiveMultiStepLR': - # Only supported when there are two optimizers: G and D. 
- assert len(self.optimizers) == 2 - self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_G, train_opt['gen_lr_steps'], - self.netG.module.get_progressive_starts(), - train_opt['lr_gamma'])) - for o in self.disc_optimizers: - self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(o, train_opt['disc_lr_steps'], - [0], - train_opt['lr_gamma'])) - elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': - for optimizer in self.optimizers: - self.schedulers.append( - lr_scheduler.CosineAnnealingLR_Restart( - optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], - restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) - else: - raise NotImplementedError('MultiStepLR learning rate scheme is enough.') - - self.log_dict = OrderedDict() - - # Swapout params - self.swapout_G_freq = train_opt['swapout_G_freq'] if train_opt['swapout_G_freq'] else 0 - self.swapout_G_duration = 0 - self.swapout_D_freq = train_opt['swapout_D_freq'] if train_opt['swapout_D_freq'] else 0 - self.swapout_D_duration = 0 - self.swapout_duration = train_opt['swapout_duration'] if train_opt['swapout_duration'] else 0 - - # GAN LQ image params - self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0 - - self.img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50 - else: - self.netF = networks.define_F(use_bn=False).to(self.device) - self.cri_fea = nn.L1Loss().to(self.device) - - #self.print_network() # print network - self.load() # load G and D if needed - self.load_random_corruptor() - - # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration. - self.updated = True - - def feed_data(self, data, need_GT=True): - _profile = True - if _profile: - from time import time - _t = time() - - # Corrupt the data with the given corruptor, if specified. - self.fed_LQ = data['LQ'].to(self.device) - if self.netC and random.random() < self.corruptor_usage_prob: - with torch.no_grad(): - corrupted_L = self.netC(self.fed_LQ)[0].detach() - else: - corrupted_L = self.fed_LQ - - self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0) - if need_GT: - self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] - input_ref = data['ref'] if 'ref' in data.keys() else data['GT'] - self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] - input_pix = data['PIX'] if 'pix' in data.keys() else data['GT'] - self.pix = [t.to(self.device) for t in torch.chunk(input_pix, chunks=self.mega_batch_factor, dim=0)] - - if 'GAN' in data.keys(): - self.gan_img = [t.to(self.device) for t in torch.chunk(data['GAN'], chunks=self.mega_batch_factor, dim=0)] - else: - # If not provided, use provided LQ for anyplace where the GAN would have been used. - self.gan_img = self.var_L - - if not self.updated: - self.netG.module.update_model(self.optimizer_G, self.schedulers[0]) - self.updated = True - - def optimize_parameters(self, step): - _profile = False - if _profile: - from time import time - _t = time() - - # Some generators have variants depending on the current step. 
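
A note on the micro-batching used throughout `optimize_parameters`: `feed_data` above splits each incoming batch into `mega_batch_factor` chunks with `torch.chunk`, and the loop below divides every loss by the same factor before `backward()`, so the accumulated gradient matches one full-batch step. A minimal sketch of that accumulation with a generic model (the `hasattr` hooks that follow implement the per-step updates mentioned in the comment above):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
mega_batch_factor = 4

inputs = torch.randn(16, 8)
targets = torch.randn(16, 1)

opt.zero_grad()
for x, y in zip(torch.chunk(inputs, mega_batch_factor, dim=0),
                torch.chunk(targets, mega_batch_factor, dim=0)):
    loss = criterion(model(x), y) / mega_batch_factor  # scale so grads sum to a full-batch step
    loss.backward()                                    # gradients accumulate across chunks
opt.step()
```
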
- if hasattr(self.netG.module, "update_for_step"): - self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) - if hasattr(self.netD.module, "update_for_step"): - self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) - - # G - for p in self.netD.parameters(): - p.requires_grad = False - if self.spsr_enabled: - for p in self.netD_grad.parameters(): - p.requires_grad = False - - self.swapout_D(step) - self.swapout_G(step) - - # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason. - if step % self.D_update_ratio == 0 and step >= self.D_init_iters: - if self.spsr_enabled and self.branch_pretrain and step < self.branch_init_iters: - for k, v in self.netG.named_parameters(): - if v.dtype != torch.int64 and v.dtype != torch.bool: - v.requires_grad = '_branch_pretrain' in k - else: - for p in self.netG.parameters(): - if p.dtype != torch.int64 and p.dtype != torch.bool: - p.requires_grad = True - else: - for p in self.netG.parameters(): - p.requires_grad = False - - # Calculate a standard deviation for the gaussian noise to be applied to the discriminator, termed noise-theta. - if self.D_noise_final == 0: - noise_theta = 0 - else: - noise_theta = (self.D_noise_theta - self.D_noise_theta_floor) * (self.D_noise_final - min(step, self.D_noise_final)) / self.D_noise_final + self.D_noise_theta_floor - - if _profile: - print("Misc setup %f" % (time() - _t,)) - _t = time() - - self.optimizer_G.zero_grad() - self.fake_GenOut = [] - self.fea_GenOut = [] - self.fake_H = [] - self.spsr_grad_GenOut = [] - var_ref_skips = [] - for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix): - if self.spsr_enabled: - using_gan_img = False - # SPSR models have outputs from three different branches. - fake_H_branch, fake_GenOut, grad_LR = self.netG(var_L) - fea_GenOut = fake_GenOut - self.spsr_grad_GenOut.append(fake_H_branch) - # Get image gradients for later use. 
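
`get_grad_nopadding` (from `SPSR_arch`) produces the edge/gradient map that the SPSR branch and its dedicated discriminator operate on. A rough functional equivalent using fixed difference kernels (a sketch only; the exact `ImageGradientNoPadding` weights may differ):

```python
import torch
import torch.nn.functional as F

def image_gradient(x: torch.Tensor) -> torch.Tensor:
    """Per-pixel gradient magnitude of an NCHW batch via fixed difference kernels."""
    kx = torch.tensor([[[[-1., 0., 1.]]]])  # horizontal difference, shape (1, 1, 1, 3)
    ky = kx.transpose(2, 3)                 # vertical difference, shape (1, 1, 3, 1)
    gray = x.mean(dim=1, keepdim=True)      # collapse channels before differencing
    gx = F.conv2d(gray, kx, padding=(0, 1))
    gy = F.conv2d(gray, ky, padding=(1, 0))
    return torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)

print(image_gradient(torch.rand(2, 3, 32, 32)).shape)  # torch.Size([2, 1, 32, 32])
```
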
- fake_H_grad = self.get_grad_nopadding(fake_GenOut) - else: - if random.random() > self.gan_lq_img_use_prob: - fea_GenOut, fake_GenOut = self.netG(var_L) - using_gan_img = False - else: - fea_GenOut, fake_GenOut = self.netG(var_LGAN) - using_gan_img = True - - if _profile: - print("Gen forward %f" % (time() - _t,)) - _t = time() - - self.fake_GenOut.append(fake_GenOut.detach()) - self.fea_GenOut.append(fea_GenOut.detach()) - - l_g_total = 0 - if step % self.D_update_ratio == 0 and step >= self.D_init_iters: - fea_w = self.l_fea_sched.get_weight_for_step(step) - l_g_pix_log = None - l_g_fea_log = None - l_g_fdpl = None - l_g_fea_log = None - if self.cri_pix and not using_gan_img: # pixel loss - l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix) - l_g_pix_log = l_g_pix / self.l_pix_w - l_g_total += l_g_pix - if self.spsr_enabled and self.cri_pix_grad: # gradient pixel loss - if self.disjoint_data: - grad_truth = self.get_grad_nopadding(var_L) - grad_pred = F.interpolate(fake_H_grad, size=grad_truth.shape[2:], mode="nearest") - else: - grad_truth = self.get_grad_nopadding(var_H) - grad_pred = fake_H_grad - l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(grad_pred, grad_truth) - l_g_total += l_g_pix_grad - if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss - if self.disjoint_data: - grad_truth = self.get_grad_nopadding(var_L) - grad_pred = F.interpolate(fake_H_branch, size=grad_truth.shape[2:], mode="nearest") - else: - grad_truth = self.get_grad_nopadding(var_H) - grad_pred = fake_H_branch - l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(grad_pred, grad_truth) - l_g_total += l_g_pix_grad_branch - if self.fdpl_enabled and not using_gan_img: - l_g_fdpl = self.cri_fdpl(fea_GenOut, pix) - l_g_total += l_g_fdpl * self.fdpl_weight - if self.cri_fea and not using_gan_img and fea_w > 0: # feature loss - if self.lr_netF is not None: - real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale']) - elif self.use_corrupted_feature_input: - cor_Pix = F.interpolate(self.feature_corruptor(pix), size=var_L.shape[2:]) - real_fea = self.netF(cor_Pix).detach() - else: - real_fea = self.netF(pix).detach() - if self.use_corrupted_feature_input: - fake_fea = self.netF(F.interpolate(self.feature_corruptor(fea_GenOut), size=var_L.shape[2:])) - else: - fake_fea = self.netF(fea_GenOut) - l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea) - l_g_fea_log = l_g_fea / fea_w - l_g_total += l_g_fea - - if _profile: - print("Fea forward %f" % (time() - _t,)) - _t = time() - - # Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931 - # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is - # equal to this value. 
If I ever come up with an algorithm that tunes fea/gan weights automatically, - # it should target this - - l_g_fix_disc = torch.zeros(1, requires_grad=False, device=self.device).squeeze() - for fixed_disc in self.fixed_disc_nets: - weight = fixed_disc.module.fdisc_weight - real_fea = fixed_disc(pix).detach() - fake_fea = fixed_disc(fea_GenOut) - l_g_fix_disc = l_g_fix_disc + weight * self.cri_fea(fake_fea, real_fea) - l_g_total += l_g_fix_disc - - - if self.l_gan_w > 0: - if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: - if self.opt['train']['gan_type'] == 'crossgan': - pred_g_fake = self.netD(fake_GenOut, var_L) - else: - pred_g_fake = self.netD(fake_GenOut) - l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) - elif self.opt['train']['gan_type'] == 'ragan': - pred_d_real = self.netD(var_ref).detach() - pred_g_fake = self.netD(fake_GenOut) - l_g_gan = self.l_gan_w * ( - self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + - self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 - l_g_gan_log = l_g_gan / self.l_gan_w - l_g_total += l_g_gan - - if self.spsr_enabled and self.cri_grad_gan: - if self.opt['train']['gan_type'] == 'crossgan': - pred_g_fake_grad = self.netD_grad(fake_H_grad, var_L) - pred_g_fake_grad_branch = self.netD_grad(fake_H_branch, var_L) - else: - pred_g_fake_grad = self.netD_grad(fake_H_grad) - pred_g_fake_grad_branch = self.netD_grad(fake_H_branch) - if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: - l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True) - l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True) - elif self.opt['train']['gan_type'] == 'ragan': - pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach() - l_g_gan_grad = self.l_gan_grad_w * ( - self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) + - self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2 - l_g_gan_grad_branch = self.l_gan_grad_w * ( - self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad_branch), False) + - self.cri_gan(pred_g_fake_grad_branch - torch.mean(pred_g_real_grad), True)) / 2 - l_g_total += l_g_gan_grad + l_g_gan_grad_branch - - # Scale the loss down by the batch factor. - l_g_total_log = l_g_total - l_g_total = l_g_total / self.mega_batch_factor - - with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled: - l_g_total_scaled.backward() - - if _profile: - print("Gen backward %f" % (time() - _t,)) - _t = time() - - self.optimizer_G.step() - - if _profile: - print("Gen step %f" % (time() - _t,)) - _t = time() - - # D - if self.l_gan_w > 0 and step >= self.G_warmup: - for p in self.netD.parameters(): - if p.dtype != torch.int64 and p.dtype != torch.bool: - p.requires_grad = True - - noise = torch.randn_like(var_ref) * noise_theta - noise.to(self.device) - real_disc_images = [] - fake_disc_images = [] - for fake_GenOut, var_LGAN, var_L, var_H, var_ref, pix in zip(self.fake_GenOut, self.gan_img, self.var_L, self.var_H, self.var_ref, self.pix): - if random.random() > self.gan_lq_img_use_prob: - fake_H = fake_GenOut.clone().detach().requires_grad_(False) - else: - # Re-compute generator outputs with the GAN inputs. 
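
The `ragan` branches above implement the relativistic average GAN objective: each prediction is judged relative to the mean prediction of the opposite class, and the two BCE terms are averaged. A compact sketch on raw logits (in the deleted code, `pred_d_real` is detached before this point):

```python
import torch
import torch.nn.functional as F

def ragan_generator_loss(pred_real: torch.Tensor, pred_fake: torch.Tensor) -> torch.Tensor:
    """Relativistic average GAN loss for the generator, on raw discriminator logits."""
    real_rel = pred_real - pred_fake.mean()  # from G's perspective, real should score lower...
    fake_rel = pred_fake - pred_real.mean()  # ...and fake should score higher
    loss_real = F.binary_cross_entropy_with_logits(real_rel, torch.zeros_like(real_rel))
    loss_fake = F.binary_cross_entropy_with_logits(fake_rel, torch.ones_like(fake_rel))
    return (loss_real + loss_fake) / 2

print(ragan_generator_loss(torch.randn(4, 1), torch.randn(4, 1)))
```
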
- with torch.no_grad(): - if self.spsr_enabled: - _, fake_H, _ = self.netG(var_LGAN) - else: - _, fake_H = self.netG(var_LGAN) - fake_H = fake_H.detach() - - if _profile: - print("Gen forward for disc %f" % (time() - _t,)) - _t = time() - - # Apply noise to the inputs to slow discriminator convergence. - var_ref = var_ref + noise - fake_H = fake_H + noise - l_d_fea_real = 0 - l_d_fea_fake = 0 - self.optimizer_D.zero_grad() - if self.opt['train']['gan_type'] == 'pixgan_fea': - # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better. - disc_fea_scale = .1 - _, fea_real = self.netD(var_ref, output_feature_vector=True) - actual_fea = self.netF(var_ref) - l_d_fea_real = self.cri_fea(fea_real, actual_fea) * disc_fea_scale / self.mega_batch_factor - _, fea_fake = self.netD(fake_H, output_feature_vector=True) - actual_fea = self.netF(fake_H) - l_d_fea_fake = self.cri_fea(fea_fake, actual_fea) * disc_fea_scale / self.mega_batch_factor - if self.opt['train']['gan_type'] == 'crossgan': - # need to forward and backward separately, since batch norm statistics differ - # real - pred_d_real = self.netD(var_ref, var_L) - l_d_real = self.cri_gan(pred_d_real, True) - l_d_real_log = l_d_real - # fake - pred_d_fake = self.netD(fake_H, var_L) - l_d_fake = self.cri_gan(pred_d_fake, False) - l_d_fake_log = l_d_fake - # mismatched - mismatched_L = torch.roll(var_L, shifts=1, dims=0) - pred_d_real_mismatched = self.netD(var_ref, mismatched_L) - pred_d_fake_mismatched = self.netD(fake_H, mismatched_L) - l_d_mismatched = (self.cri_gan(pred_d_real_mismatched, False) + self.cri_gan(pred_d_fake_mismatched, False)) / 2 - - l_d_total = (l_d_real + l_d_fake + l_d_mismatched) / 3 - l_d_total = l_d_total / self.mega_batch_factor - with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: - l_d_total_scaled.backward() - elif self.opt['train']['gan_type'] == 'gan': - # real - pred_d_real = self.netD(var_ref) - l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor - l_d_real_log = l_d_real * self.mega_batch_factor - # fake - pred_d_fake = self.netD(fake_H) - l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor - l_d_fake_log = l_d_fake * self.mega_batch_factor - - l_d_total = (l_d_real + l_d_fake) / 2 - with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: - l_d_total_scaled.backward() - elif 'pixgan' in self.opt['train']['gan_type']: - pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() - disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) - b, _, w, h = var_ref.shape - real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device) - fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device) - if not self.disjoint_data: - # randomly determine portions of the image to swap to keep the discriminator honest. - SWAP_MAX_DIM = w // 4 - SWAP_MIN_DIM = 16 - assert SWAP_MAX_DIM > 0 - if random.random() > .5: # Make this only happen half the time. Earlier experiments had it happen - # more often and the model was "cheating" by using the presence of - # easily discriminated fake swaps to count the entire generated image - # as fake. 
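
The swap logic that follows is what the comment above describes: random rectangles are exchanged between the reference and generated images, and the per-pixel "pixgan" label maps are flipped in the same regions, so the discriminator must localize fakes rather than classify whole images. A reduced sketch of one swap on NCHW tensors:

```python
import random

import torch

def swap_patch(fake: torch.Tensor, real: torch.Tensor, labels_real: torch.Tensor,
               labels_fake: torch.Tensor, min_dim: int = 4, max_dim: int = 8) -> None:
    """Exchange one random rectangle between fake/real batches, flipping labels to match."""
    _, _, h, w = real.shape
    x = random.randint(0, w - min_dim)
    y = random.randint(0, h - min_dim)
    pw = min(random.randint(min_dim, max_dim), w - x)
    ph = min(random.randint(min_dim, max_dim), h - y)
    patch = fake[:, :, y:y + ph, x:x + pw].clone()
    fake[:, :, y:y + ph, x:x + pw] = real[:, :, y:y + ph, x:x + pw]
    real[:, :, y:y + ph, x:x + pw] = patch
    labels_real[:, :, y:y + ph, x:x + pw] = 0.0  # this region of "real" is now fake
    labels_fake[:, :, y:y + ph, x:x + pw] = 1.0  # and this region of "fake" is now real

fake, real = torch.zeros(1, 3, 16, 16), torch.ones(1, 3, 16, 16)
lr, lf = torch.ones(1, 1, 16, 16), torch.zeros(1, 1, 16, 16)
swap_patch(fake, real, lr, lf)
```
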
- random_swap_count = random.randint(0, 4) - for i in range(random_swap_count): - # Make the swap across fake_H and var_ref - swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM) - swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM) - if swap_x + swap_w > w: - swap_w = w - swap_x - if swap_y + swap_h > h: - swap_h = h - swap_y - t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone() - fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] - var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t - real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0 - fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0 - - # Interpolate down to the dimensionality that the discriminator uses. - real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear", align_corners=False) - fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear", align_corners=False) - - # We're also assuming that this is exactly how the flattened discriminator output is generated. - real = real.view(-1, 1) - fake = fake.view(-1, 1) - - # real - pred_d_real = self.netD(var_ref) - l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor - l_d_real_log = l_d_real * self.mega_batch_factor - l_d_real += l_d_fea_real - # fake - pred_d_fake = self.netD(fake_H) - l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor - l_d_fake_log = l_d_fake * self.mega_batch_factor - l_d_fake += l_d_fea_fake - - l_d_total = (l_d_real + l_d_fake) / 2 - with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: - l_d_total_scaled.backward() - - pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real)) - pdr = pdr / torch.max(pdr) - real_disc_images.append(pdr.view(disc_output_shape)) - pdf = pred_d_fake.detach() + torch.abs(torch.min(pred_d_fake)) - pdf = pdf / torch.max(pdf) - fake_disc_images.append(pdf.view(disc_output_shape)) - elif self.opt['train']['gan_type'] == 'ragan': - pred_d_fake = self.netD(fake_H) - pred_d_real = self.netD(var_ref) - l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) - l_d_real_log = l_d_real - l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) - l_d_fake_log = l_d_fake - l_d_total = (l_d_real + l_d_fake) / 2 - l_d_total /= self.mega_batch_factor - with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: - l_d_total_scaled.backward() - var_ref_skips.append(var_ref.detach()) - self.fake_H.append(fake_H.detach()) - self.optimizer_D.step() - - if _profile: - print("Disc step %f" % (time() - _t,)) - _t = time() - - # D_grad. - if self.spsr_enabled and self.cri_grad_gan and step >= self.G_warmup: - for p in self.netD_grad.parameters(): - p.requires_grad = True - self.optimizer_D_grad.zero_grad() - for var_L, var_ref, fake_H, fake_H_grad_branch in zip(self.var_L, var_ref_skips, self.fake_H, self.spsr_grad_GenOut): - fake_H_grad = self.get_grad_nopadding(fake_H).detach() - var_ref_grad = self.get_grad_nopadding(var_ref) - fake_H_grad_branch = fake_H_grad_branch.detach() + noise - if self.opt['train']['gan_type'] == 'crossgan': - pred_d_real_grad = self.netD_grad(var_ref_grad, var_L) - pred_d_fake_grad = self.netD_grad(fake_H_grad, var_L) # Tensor already detached above. - # var_ref and fake_H already has noise added to it. We **must** add noise to fake_H_grad_branch too. 
- pred_d_fake_grad_branch = self.netD_grad(fake_H_grad_branch, var_L) - else: - pred_d_real_grad = self.netD_grad(var_ref_grad) - pred_d_fake_grad = self.netD_grad(fake_H_grad) # Tensor already detached above. - # var_ref and fake_H already has noise added to it. We **must** add noise to fake_H_grad_branch too. - pred_d_fake_grad_branch = self.netD_grad(fake_H_grad_branch) - if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'crossgan': - l_d_real_grad = self.cri_gan(pred_d_real_grad, True) - l_d_fake_grad = (self.cri_gan(pred_d_fake_grad, False) + self.cri_gan(pred_d_fake_grad_branch, False)) / 2 - elif self.opt['train']['gan_type'] == 'pixgan': - real = torch.ones_like(pred_d_real_grad) - fake = torch.zeros_like(pred_d_fake_grad) - l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real) - l_d_fake_grad = (self.cri_grad_gan(pred_d_fake_grad, fake) + \ - self.cri_grad_gan(pred_d_fake_grad_branch, fake)) / 2 - elif self.opt['train']['gan_type'] == 'ragan': - l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True) - l_d_fake_grad = (self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) + \ - self.cri_grad_gan(pred_d_fake_grad_branch - torch.mean(pred_d_real_grad), False)) / 2 - - l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2 - l_d_total_grad /= self.mega_batch_factor - with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled: - l_d_total_grad_scaled.backward() - self.optimizer_D_grad.step() - - - # Log sample images from first microbatch. - if step % self.img_debug_steps == 0 and self.rank <= 0: - sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp") - os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True) - os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True) - os.makedirs(os.path.join(sample_save_path, "gen_fea"), exist_ok=True) - os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True) - os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True) - os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True) - os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True) - os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True) - if self.spsr_enabled: - os.makedirs(os.path.join(sample_save_path, "gen_grad"), exist_ok=True) - - # fed_LQ is not chunked. 
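
The debug block below dumps one sample per micro-batch into `temp/` subfolders so training progress can be inspected visually. `torchvision.utils.save_image` does the tensor-to-PNG conversion and tiles 4-D inputs into a grid; minimal usage:

```python
import os

import torch
import torchvision.utils as utils

os.makedirs("temp/gen", exist_ok=True)
batch = torch.rand(8, 3, 32, 32)  # NCHW values in [0, 1]
step, microbatch = 100, 0
# A 4-D tensor is tiled into a single grid image automatically.
utils.save_image(batch, os.path.join("temp", "gen", "%05i_%02i.png" % (step, microbatch)))
```
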
- for i in range(self.mega_batch_factor): - utils.save_image(self.var_H[i].cpu(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i))) - utils.save_image(self.var_L[i].cpu(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i))) - utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i))) - utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) - utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i))) - if self.spsr_enabled: - utils.save_image(self.spsr_grad_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_grad", "%05i_%02i.png" % (step, i))) - if self.l_gan_w > 0 and step >= self.G_warmup and 'pixgan' in self.opt['train']['gan_type']: - utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i))) - utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i))) - utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i))) - utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i))) - - # Log metrics - if step % self.D_update_ratio == 0 and step >= self.D_init_iters: - if self.cri_pix and l_g_pix_log is not None: - self.add_log_entry('l_g_pix', l_g_pix_log.detach().item()) - if self.fdpl_enabled and l_g_fdpl is not None: - self.add_log_entry('l_g_fdpl', l_g_fdpl.detach().item()) - if self.cri_fea and l_g_fea_log is not None: - self.add_log_entry('feature_weight', fea_w) - self.add_log_entry('l_g_fea', l_g_fea_log.detach().item()) - self.add_log_entry('l_g_fix_disc', l_g_fix_disc.detach().item()) - if self.l_gan_w > 0: - self.add_log_entry('l_g_gan', l_g_gan_log.detach().item()) - self.add_log_entry('l_g_total', l_g_total_log.detach().item()) - if self.opt['train']['gan_type'] == 'pixgan_fea': - self.add_log_entry('l_d_fea_fake', l_d_fea_fake.detach().item() * self.mega_batch_factor) - self.add_log_entry('l_d_fea_real', l_d_fea_real.detach().item() * self.mega_batch_factor) - self.add_log_entry('l_d_fake_total', l_d_fake.detach().item() * self.mega_batch_factor) - self.add_log_entry('l_d_real_total', l_d_real.detach().item() * self.mega_batch_factor) - if self.opt['train']['gan_type'] == 'crossgan': - self.add_log_entry('l_d_mismatched', l_d_mismatched.detach().item()) - if self.spsr_enabled: - if self.cri_pix_grad: - self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad.detach().item()) - if self.cri_pix_branch: - self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.detach().item()) - if self.cri_grad_gan: - self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item() / self.l_gan_grad_w) - self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item() / self.l_gan_grad_w) - if self.l_gan_w > 0 and step >= self.G_warmup: - self.add_log_entry('l_d_real', l_d_real_log.detach().item()) - self.add_log_entry('l_d_fake', l_d_fake_log.detach().item()) - self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) - self.add_log_entry('D_diff', torch.mean(pred_d_fake.detach()) - torch.mean(pred_d_real.detach())) - if self.spsr_enabled and self.cri_grad_gan: - self.add_log_entry('l_d_real_grad', l_d_real_grad.detach().item()) - self.add_log_entry('l_d_fake_grad', l_d_fake_grad.detach().item()) - 
self.add_log_entry('D_fake_grad', torch.mean(pred_d_fake_grad.detach())) - self.add_log_entry('D_diff_grad', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach())) - - # Log learning rates. - for i, pg in enumerate(self.optimizer_G.param_groups): - self.add_log_entry('gen_lr_%i' % (i,), pg['lr']) - for i, pg in enumerate(self.optimizer_D.param_groups): - self.add_log_entry('disc_lr_%i' % (i,), pg['lr']) - - if step % self.corruptor_swapout_steps == 0 and step > 0: - self.load_random_corruptor() - - # Allows the log to serve as an easy-to-use rotating buffer. - def add_log_entry(self, key, value): - key_it = "%s_it" % (key,) - log_rotating_buffer_size = 50 - if key not in self.log_dict.keys(): - self.log_dict[key] = [] - self.log_dict[key_it] = 0 - if len(self.log_dict[key]) < log_rotating_buffer_size: - self.log_dict[key].append(value) - else: - self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value - self.log_dict[key_it] += 1 - - def pick_rand_prev_model(self, model_suffix): - previous_models = glob.glob(os.path.join(self.opt['path']['models'], "*_%s.pth" % (model_suffix,))) - if len(previous_models) <= 1: - return None - # Just a note: this intentionally includes the swap model in the list of possibilities. - return previous_models[random.randint(0, len(previous_models)-1)] - - def compute_fea_loss(self, real, fake): - with torch.no_grad(): - real = real.unsqueeze(dim=0).to(self.device) - fake = fake.unsqueeze(dim=0).to(self.device) - real_fea = self.netF(real) - fake_fea = self.netF(fake) - return self.cri_fea(fake_fea, real_fea).item() - - # Called before verification/checkpoint to ensure we're using the real models and not a swapout variant. - def force_restore_swapout(self): - if self.swapout_D_duration > 0: - logger.info("Swapping back to current D model: %s" % (self.stashed_D,)) - self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load']) - self.stashed_D = None - self.swapout_D_duration = 0 - if self.swapout_G_duration > 0: - logger.info("Swapping back to current G model: %s" % (self.stashed_G,)) - self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load']) - self.stashed_G = None - self.swapout_G_duration = 0 - - def swapout_D(self, step): - if self.swapout_D_duration > 0: - self.swapout_D_duration -= 1 - if self.swapout_D_duration == 0: - # Swap back. - logger.info("Swapping back to current D model: %s" % (self.stashed_D,)) - self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load']) - self.stashed_D = None - elif self.swapout_D_freq != 0 and step % self.swapout_D_freq == 0: - swapped_model = self.pick_rand_prev_model('D') - if swapped_model is not None: - logger.info("Swapping to previous D model: %s" % (swapped_model,)) - self.stashed_D = self.save_network(self.netD, 'D', 'swap_model') - self.load_network(swapped_model, self.netD, self.opt['path']['strict_load']) - self.swapout_D_duration = self.swapout_duration - - def swapout_G(self, step): - if self.swapout_G_duration > 0: - self.swapout_G_duration -= 1 - if self.swapout_G_duration == 0: - # Swap back. 
- logger.info("Swapping back to current G model: %s" % (self.stashed_G,)) - self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load']) - self.stashed_G = None - elif self.swapout_G_freq != 0 and step % self.swapout_G_freq == 0: - swapped_model = self.pick_rand_prev_model('G') - if swapped_model is not None: - logger.info("Swapping to previous G model: %s" % (swapped_model,)) - self.stashed_G = self.save_network(self.netG, 'G', 'swap_model') - self.load_network(swapped_model, self.netG, self.opt['path']['strict_load']) - self.swapout_G_duration = self.swapout_duration - - def test(self): - self.netG.eval() - with torch.no_grad(): - if self.spsr_enabled: - self.fake_H_branch = [] - self.fake_GenOut = [] - self.grad_LR = [] - fake_H_branch, fake_GenOut, grad_LR = self.netG(self.var_L[0]) - self.fake_H_branch.append(fake_H_branch) - self.fake_GenOut.append(fake_GenOut) - self.grad_LR.append(grad_LR) - else: - self.fake_GenOut = [self.netG(self.var_L[0])] - self.netG.train() - - # Fetches a summary of the log. - def get_current_log(self, step): - return_log = {} - for k in self.log_dict.keys(): - if not isinstance(self.log_dict[k], list): - continue - return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k]) - - # Some generators can do their own metric logging. - if hasattr(self.netG.module, "get_debug_values"): - return_log.update(self.netG.module.get_debug_values(step)) - if hasattr(self.netD.module, "get_debug_values"): - return_log.update(self.netD.module.get_debug_values(step)) - - return return_log - - def get_current_visuals(self, need_GT=True): - out_dict = OrderedDict() - out_dict['LQ'] = self.var_L[0].detach().float().cpu() - gen_batch = self.fake_GenOut[0] - if isinstance(gen_batch, tuple): - gen_batch = gen_batch[0] - out_dict['rlt'] = gen_batch.detach().float().cpu() - if need_GT: - out_dict['GT'] = self.var_H[0].detach().float().cpu() - if self.spsr_enabled: - out_dict['SR_branch'] = self.fake_H_branch[0].float().cpu() - out_dict['LR_grad'] = self.grad_LR[0].float().cpu() - return out_dict - - def print_network(self): - # Generator - s, n = self.get_network_description(self.netG) - if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): - net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, - self.netG.module.__class__.__name__) - else: - net_struc_str = '{}'.format(self.netG.__class__.__name__) - if self.rank <= 0: - logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) - logger.info(s) - if self.is_train: - # Discriminator - s, n = self.get_network_description(self.netD) - if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD, - DistributedDataParallel): - net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, - self.netD.module.__class__.__name__) - else: - net_struc_str = '{}'.format(self.netD.__class__.__name__) - if self.rank <= 0: - logger.info('Network D structure: {}, with parameters: {:,d}'.format( - net_struc_str, n)) - logger.info(s) - - if self.cri_fea: # F, Perceptual Network - s, n = self.get_network_description(self.netF) - if isinstance(self.netF, nn.DataParallel) or isinstance( - self.netF, DistributedDataParallel): - net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, - self.netF.module.__class__.__name__) - else: - net_struc_str = '{}'.format(self.netF.__class__.__name__) - if self.rank <= 0: - logger.info('Network F structure: {}, with parameters: {:,d}'.format( - net_struc_str, n)) - logger.info(s) - - def load(self): - 
load_path_G = self.opt['path']['pretrain_model_G'] - if load_path_G is not None: - logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) - self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) - load_path_D = self.opt['path']['pretrain_model_D'] - if self.opt['is_train'] and load_path_D is not None: - logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) - self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) - if self.spsr_enabled: - load_path_D_grad = self.opt['path']['pretrain_model_D_grad'] - if self.opt['is_train'] and load_path_D_grad is not None: - logger.info('Loading pretrained model for D_grad [{:s}] ...'.format(load_path_D_grad)) - self.load_network(load_path_D_grad, self.netD_grad) - - def load_random_corruptor(self): - if self.netC is None: - return - corruptor_files = glob.glob(os.path.join(self.opt['path']['pretrained_corruptors_dir'], "*.pth")) - corruptor_to_load = corruptor_files[random.randint(0, len(corruptor_files)-1)] - logger.info('Swapping corruptor to: %s' % (corruptor_to_load,)) - self.load_network(corruptor_to_load, self.netC, self.opt['path']['strict_load']) - - def save(self, iter_step): - self.save_network(self.netG, 'G', iter_step) - self.save_network(self.netD, 'D', iter_step) - if self.spsr_enabled: - self.save_network(self.netD_grad, 'D_grad', iter_step) diff --git a/codes/models/SR_model.py b/codes/models/SR_model.py deleted file mode 100644 index 02fc7a7c..00000000 --- a/codes/models/SR_model.py +++ /dev/null @@ -1,171 +0,0 @@ -import logging -from collections import OrderedDict - -import torch -import torch.nn as nn -from torch.nn.parallel import DataParallel, DistributedDataParallel -import models.networks as networks -import models.lr_scheduler as lr_scheduler -from .base_model import BaseModel -from models.loss import CharbonnierLoss -from apex import amp - -logger = logging.getLogger('base') - - -class SRModel(BaseModel): - def __init__(self, opt): - super(SRModel, self).__init__(opt) - - if opt['dist']: - self.rank = torch.distributed.get_rank() - else: - self.rank = -1 # non dist training - train_opt = opt['train'] - - # define network and load pretrained models - self.netG = amp.initialize(networks.define_G(opt).to(self.device), opt_level=self.amp_level) - if opt['dist']: - self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) - elif opt['gpu_ids'] is not None: - self.netG = DataParallel(self.netG) - # print network - self.print_network() - self.load() - - if self.is_train: - self.netG.train() - - # loss - loss_type = train_opt['pixel_criterion'] - if loss_type == 'l1': - self.cri_pix = nn.L1Loss().to(self.device) - elif loss_type == 'l2': - self.cri_pix = nn.MSELoss().to(self.device) - elif loss_type == 'cb': - self.cri_pix = CharbonnierLoss().to(self.device) - else: - raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) - self.l_pix_w = train_opt['pixel_weight'] - - # optimizers - wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 - optim_params = [] - for k, v in self.netG.named_parameters(): # can optimize for a part of the model - if v.requires_grad: - optim_params.append(v) - else: - if self.rank <= 0: - logger.warning('Params [{:s}] will not optimize.'.format(k)) - self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], - weight_decay=wd_G, - betas=(train_opt['beta1'], train_opt['beta2'])) - self.optimizers.append(self.optimizer_G) - - # schedulers - if 
train_opt['lr_scheme'] == 'MultiStepLR': - for optimizer in self.optimizers: - self.schedulers.append( - lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], - restarts=train_opt['restarts'], - weights=train_opt['restart_weights'], - gamma=train_opt['lr_gamma'], - clear_state=train_opt['clear_state'])) - elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': - for optimizer in self.optimizers: - self.schedulers.append( - lr_scheduler.CosineAnnealingLR_Restart( - optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], - restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) - else: - raise NotImplementedError('MultiStepLR learning rate scheme is enough.') - - self.log_dict = OrderedDict() - - def feed_data(self, data, need_GT=True): - self.var_L = data['LQ'].to(self.device) # LQ - if need_GT: - self.real_H = data['GT'].to(self.device) # GT - - def optimize_parameters(self, step): - self.optimizer_G.zero_grad() - self.fake_H = self.netG(self.var_L) - l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) - l_pix.backward() - self.optimizer_G.step() - - # set log - self.log_dict['l_pix'] = l_pix.item() - - def test(self): - self.netG.eval() - with torch.no_grad(): - self.fake_H = self.netG(self.var_L) - self.netG.train() - - def test_x8(self): - # from https://github.com/thstkdgus35/EDSR-PyTorch - self.netG.eval() - - def _transform(v, op): - # if self.precision != 'single': v = v.float() - v2np = v.data.cpu().numpy() - if op == 'v': - tfnp = v2np[:, :, :, ::-1].copy() - elif op == 'h': - tfnp = v2np[:, :, ::-1, :].copy() - elif op == 't': - tfnp = v2np.transpose((0, 1, 3, 2)).copy() - - ret = torch.Tensor(tfnp).to(self.device) - # if self.precision == 'half': ret = ret.half() - - return ret - - lr_list = [self.var_L] - for tf in 'v', 'h', 't': - lr_list.extend([_transform(t, tf) for t in lr_list]) - with torch.no_grad(): - sr_list = [self.netG(aug) for aug in lr_list] - for i in range(len(sr_list)): - if i > 3: - sr_list[i] = _transform(sr_list[i], 't') - if i % 4 > 1: - sr_list[i] = _transform(sr_list[i], 'h') - if (i % 4) % 2 == 1: - sr_list[i] = _transform(sr_list[i], 'v') - - output_cat = torch.cat(sr_list, dim=0) - self.fake_H = output_cat.mean(dim=0, keepdim=True) - self.netG.train() - - def get_current_log(self): - return self.log_dict - - def get_current_visuals(self, need_GT=True): - out_dict = OrderedDict() - out_dict['LQ'] = self.var_L.detach().float().cpu() - out_dict['rlt'] = self.fake_H.detach().float().cpu() - if need_GT: - out_dict['GT'] = self.real_H.detach().float().cpu() - return out_dict - - def print_network(self): - s, n = self.get_network_description(self.netG) - if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): - net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, - self.netG.module.__class__.__name__) - else: - net_struc_str = '{}'.format(self.netG.__class__.__name__) - if self.rank <= 0: - logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) - logger.info(s) - - def load(self): - load_path_G = self.opt['path']['pretrain_model_G'] - if load_path_G is not None: - logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) - self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) - - def save(self, iter_label): - self.save_network(self.netG, 'G', iter_label) diff --git a/codes/models/__init__.py b/codes/models/__init__.py deleted file mode 100644 index 26f0e1fa..00000000 --- a/codes/models/__init__.py +++ 
/dev/null @@ -1,22 +0,0 @@ -import logging -logger = logging.getLogger('base') - - -def create_model(opt): - model = opt['model'] - # image restoration - if model == 'sr': # PSNR-oriented super resolution - from .SR_model import SRModel as M - elif model == 'srgan' or model == 'corruptgan' or model == 'spsrgan': - from .SRGAN_model import SRGANModel as M - elif model == 'feat': - from .feature_model import FeatureModel as M - elif model == 'spsr': - from .SPSR_model import SPSRModel as M - elif model == 'extensibletrainer': - from .ExtensibleTrainer import ExtensibleTrainer as M - else: - raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) - m = M(opt) - logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) - return m diff --git a/codes/models/layers/channelnorm_package/__init__.py b/codes/models/layers/channelnorm_package/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/models/layers/channelnorm_package/channelnorm.py b/codes/models/layers/channelnorm_package/channelnorm.py deleted file mode 100644 index 76301af5..00000000 --- a/codes/models/layers/channelnorm_package/channelnorm.py +++ /dev/null @@ -1,39 +0,0 @@ -from torch.autograd import Function, Variable -from torch.nn.modules.module import Module -import channelnorm_cuda - -class ChannelNormFunction(Function): - - @staticmethod - def forward(ctx, input1, norm_deg=2): - assert input1.is_contiguous() - b, _, h, w = input1.size() - output = input1.new(b, 1, h, w).zero_() - - channelnorm_cuda.forward(input1, output, norm_deg) - ctx.save_for_backward(input1, output) - ctx.norm_deg = norm_deg - - return output - - @staticmethod - def backward(ctx, grad_output): - input1, output = ctx.saved_tensors - - grad_input1 = Variable(input1.new(input1.size()).zero_()) - - channelnorm_cuda.backward(input1, output, grad_output.data, - grad_input1.data, ctx.norm_deg) - - return grad_input1, None - - -class ChannelNorm(Module): - - def __init__(self, norm_deg=2): - super(ChannelNorm, self).__init__() - self.norm_deg = norm_deg - - def forward(self, input1): - return ChannelNormFunction.apply(input1, self.norm_deg) - diff --git a/codes/models/layers/channelnorm_package/channelnorm_cuda.cc b/codes/models/layers/channelnorm_package/channelnorm_cuda.cc deleted file mode 100644 index 69d82eb1..00000000 --- a/codes/models/layers/channelnorm_package/channelnorm_cuda.cc +++ /dev/null @@ -1,31 +0,0 @@ -#include -#include - -#include "channelnorm_kernel.cuh" - -int channelnorm_cuda_forward( - at::Tensor& input1, - at::Tensor& output, - int norm_deg) { - - channelnorm_kernel_forward(input1, output, norm_deg); - return 1; -} - - -int channelnorm_cuda_backward( - at::Tensor& input1, - at::Tensor& output, - at::Tensor& gradOutput, - at::Tensor& gradInput1, - int norm_deg) { - - channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg); - return 1; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)"); - m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)"); -} - diff --git a/codes/models/layers/channelnorm_package/channelnorm_kernel.cuh b/codes/models/layers/channelnorm_package/channelnorm_kernel.cuh deleted file mode 100644 index 3e6223f7..00000000 --- a/codes/models/layers/channelnorm_package/channelnorm_kernel.cuh +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include - -void channelnorm_kernel_forward( - at::Tensor& input1, - at::Tensor& output, - int norm_deg); - - -void 
channelnorm_kernel_backward( - at::Tensor& input1, - at::Tensor& output, - at::Tensor& gradOutput, - at::Tensor& gradInput1, - int norm_deg); diff --git a/codes/models/layers/channelnorm_package/setup.py b/codes/models/layers/channelnorm_package/setup.py deleted file mode 100644 index 5b9e86a4..00000000 --- a/codes/models/layers/channelnorm_package/setup.py +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env python3 -import os -import torch - -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -cxx_args = ['-std=c++11'] - -nvcc_args = [ - '-gencode', 'arch=compute_52,code=sm_52', - '-gencode', 'arch=compute_60,code=sm_60', - '-gencode', 'arch=compute_61,code=sm_61', - '-gencode', 'arch=compute_70,code=sm_70', - '-gencode', 'arch=compute_70,code=compute_70' -] - -setup( - name='channelnorm_cuda', - ext_modules=[ - CUDAExtension('channelnorm_cuda', [ - 'channelnorm_cuda.cc', - 'channelnorm_kernel.cu' - ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) - ], - cmdclass={ - 'build_ext': BuildExtension - }) diff --git a/codes/models/layers/correlation_package/__init__.py b/codes/models/layers/correlation_package/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/models/layers/correlation_package/correlation.py b/codes/models/layers/correlation_package/correlation.py deleted file mode 100644 index 2dcdf6eb..00000000 --- a/codes/models/layers/correlation_package/correlation.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch -from torch.nn.modules.module import Module -from torch.autograd import Function -import correlation_cuda - -class CorrelationFunction(Function): - - @staticmethod - def forward(ctx, input1, input2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): - ctx.save_for_backward(input1, input2) - - ctx.pad_size = pad_size - ctx.kernel_size = kernel_size - ctx.max_displacement = max_displacement - ctx.stride1 = stride1 - ctx.stride2 = stride2 - ctx.corr_multiply = corr_multiply - - with torch.cuda.device_of(input1): - rbot1 = input1.new() - rbot2 = input2.new() - output = input1.new() - - correlation_cuda.forward(input1, input2, rbot1, rbot2, output, - ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply) - - return output - - @staticmethod - def backward(ctx, grad_output): - input1, input2 = ctx.saved_tensors - - with torch.cuda.device_of(input1): - rbot1 = input1.new() - rbot2 = input2.new() - - grad_input1 = input1.new() - grad_input2 = input2.new() - - correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, - ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply) - - return grad_input1, grad_input2, None, None, None, None, None, None - - -class Correlation(Module): - def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): - super(Correlation, self).__init__() - self.pad_size = pad_size - self.kernel_size = kernel_size - self.max_displacement = max_displacement - self.stride1 = stride1 - self.stride2 = stride2 - self.corr_multiply = corr_multiply - - def forward(self, input1, input2): - - result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply) - - return result - diff --git a/codes/models/layers/correlation_package/correlation_cuda.cc b/codes/models/layers/correlation_package/correlation_cuda.cc 
deleted file mode 100644 index feccd652..00000000 --- a/codes/models/layers/correlation_package/correlation_cuda.cc +++ /dev/null @@ -1,173 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "correlation_cuda_kernel.cuh" - -int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply) -{ - - int batchSize = input1.size(0); - - int nInputChannels = input1.size(1); - int inputHeight = input1.size(2); - int inputWidth = input1.size(3); - - int kernel_radius = (kernel_size - 1) / 2; - int border_radius = kernel_radius + max_displacement; - - int paddedInputHeight = inputHeight + 2 * pad_size; - int paddedInputWidth = inputWidth + 2 * pad_size; - - int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); - - int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); - int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); - - rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); - rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); - output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); - - rInput1.fill_(0); - rInput2.fill_(0); - output.fill_(0); - - int success = correlation_forward_cuda_kernel( - output, - output.size(0), - output.size(1), - output.size(2), - output.size(3), - output.stride(0), - output.stride(1), - output.stride(2), - output.stride(3), - input1, - input1.size(1), - input1.size(2), - input1.size(3), - input1.stride(0), - input1.stride(1), - input1.stride(2), - input1.stride(3), - input2, - input2.size(1), - input2.stride(0), - input2.stride(1), - input2.stride(2), - input2.stride(3), - rInput1, - rInput2, - pad_size, - kernel_size, - max_displacement, - stride1, - stride2, - corr_type_multiply, - at::cuda::getCurrentCUDAStream() - //at::globalContext().getCurrentCUDAStream() - ); - - //check for errors - if (!success) { - AT_ERROR("CUDA call failed"); - } - - return 1; - -} - -int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, - at::Tensor& gradInput1, at::Tensor& gradInput2, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply) -{ - - int batchSize = input1.size(0); - int nInputChannels = input1.size(1); - int paddedInputHeight = input1.size(2)+ 2 * pad_size; - int paddedInputWidth = input1.size(3)+ 2 * pad_size; - - int height = input1.size(2); - int width = input1.size(3); - - rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); - rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); - gradInput1.resize_({batchSize, nInputChannels, height, width}); - gradInput2.resize_({batchSize, nInputChannels, height, width}); - - rInput1.fill_(0); - rInput2.fill_(0); - gradInput1.fill_(0); - gradInput2.fill_(0); - - int success = correlation_backward_cuda_kernel(gradOutput, - gradOutput.size(0), - gradOutput.size(1), - gradOutput.size(2), - gradOutput.size(3), - gradOutput.stride(0), - gradOutput.stride(1), - gradOutput.stride(2), - gradOutput.stride(3), - input1, - input1.size(1), - input1.size(2), - input1.size(3), - input1.stride(0), - input1.stride(1), - input1.stride(2), - 
input1.stride(3), - input2, - input2.stride(0), - input2.stride(1), - input2.stride(2), - input2.stride(3), - gradInput1, - gradInput1.stride(0), - gradInput1.stride(1), - gradInput1.stride(2), - gradInput1.stride(3), - gradInput2, - gradInput2.size(1), - gradInput2.stride(0), - gradInput2.stride(1), - gradInput2.stride(2), - gradInput2.stride(3), - rInput1, - rInput2, - pad_size, - kernel_size, - max_displacement, - stride1, - stride2, - corr_type_multiply, - at::cuda::getCurrentCUDAStream() - //at::globalContext().getCurrentCUDAStream() - ); - - if (!success) { - AT_ERROR("CUDA call failed"); - } - - return 1; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); - m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); -} - diff --git a/codes/models/layers/correlation_package/correlation_cuda_kernel.cuh b/codes/models/layers/correlation_package/correlation_cuda_kernel.cuh deleted file mode 100644 index 1586d3af..00000000 --- a/codes/models/layers/correlation_package/correlation_cuda_kernel.cuh +++ /dev/null @@ -1,91 +0,0 @@ -#pragma once - -#include -#include -#include - -int correlation_forward_cuda_kernel(at::Tensor& output, - int ob, - int oc, - int oh, - int ow, - int osb, - int osc, - int osh, - int osw, - - at::Tensor& input1, - int ic, - int ih, - int iw, - int isb, - int isc, - int ish, - int isw, - - at::Tensor& input2, - int gc, - int gsb, - int gsc, - int gsh, - int gsw, - - at::Tensor& rInput1, - at::Tensor& rInput2, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply, - cudaStream_t stream); - - -int correlation_backward_cuda_kernel( - at::Tensor& gradOutput, - int gob, - int goc, - int goh, - int gow, - int gosb, - int gosc, - int gosh, - int gosw, - - at::Tensor& input1, - int ic, - int ih, - int iw, - int isb, - int isc, - int ish, - int isw, - - at::Tensor& input2, - int gsb, - int gsc, - int gsh, - int gsw, - - at::Tensor& gradInput1, - int gisb, - int gisc, - int gish, - int gisw, - - at::Tensor& gradInput2, - int ggc, - int ggsb, - int ggsc, - int ggsh, - int ggsw, - - at::Tensor& rInput1, - at::Tensor& rInput2, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply, - cudaStream_t stream); diff --git a/codes/models/layers/correlation_package/setup.py b/codes/models/layers/correlation_package/setup.py deleted file mode 100644 index 48b7d73a..00000000 --- a/codes/models/layers/correlation_package/setup.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python3 -import os -import torch - -from setuptools import setup, find_packages -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -cxx_args = ['-std=c++11'] - -nvcc_args = [ - '-gencode', 'arch=compute_50,code=sm_50', - '-gencode', 'arch=compute_52,code=sm_52', - '-gencode', 'arch=compute_60,code=sm_60', - '-gencode', 'arch=compute_61,code=sm_61', - '-gencode', 'arch=compute_70,code=sm_70', - '-gencode', 'arch=compute_70,code=compute_70' -] - -setup( - name='correlation_cuda', - ext_modules=[ - CUDAExtension('correlation_cuda', [ - 'correlation_cuda.cc', - 'correlation_cuda_kernel.cu' - ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) - ], - cmdclass={ - 'build_ext': BuildExtension - }) diff --git a/codes/models/layers/resample2d_package/__init__.py b/codes/models/layers/resample2d_package/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git 
a/codes/models/layers/resample2d_package/resample2d.py b/codes/models/layers/resample2d_package/resample2d.py deleted file mode 100644 index 92ea0d01..00000000 --- a/codes/models/layers/resample2d_package/resample2d.py +++ /dev/null @@ -1,49 +0,0 @@ -from torch.nn.modules.module import Module -from torch.autograd import Function, Variable -import resample2d_cuda - -class Resample2dFunction(Function): - - @staticmethod - def forward(ctx, input1, input2, kernel_size=1, bilinear= True): - assert input1.is_contiguous() - assert input2.is_contiguous() - - ctx.save_for_backward(input1, input2) - ctx.kernel_size = kernel_size - ctx.bilinear = bilinear - - _, d, _, _ = input1.size() - b, _, h, w = input2.size() - output = input1.new(b, d, h, w).zero_() - - resample2d_cuda.forward(input1, input2, output, kernel_size, bilinear) - - return output - - @staticmethod - def backward(ctx, grad_output): - grad_output = grad_output.contiguous() - assert grad_output.is_contiguous() - - input1, input2 = ctx.saved_tensors - - grad_input1 = Variable(input1.new(input1.size()).zero_()) - grad_input2 = Variable(input1.new(input2.size()).zero_()) - - resample2d_cuda.backward(input1, input2, grad_output.data, - grad_input1.data, grad_input2.data, - ctx.kernel_size, ctx.bilinear) - - return grad_input1, grad_input2, None, None - -class Resample2d(Module): - - def __init__(self, kernel_size=1, bilinear = True): - super(Resample2d, self).__init__() - self.kernel_size = kernel_size - self.bilinear = bilinear - - def forward(self, input1, input2): - input1_c = input1.contiguous() - return Resample2dFunction.apply(input1_c, input2, self.kernel_size, self.bilinear) diff --git a/codes/models/layers/resample2d_package/resample2d_cuda.cc b/codes/models/layers/resample2d_package/resample2d_cuda.cc deleted file mode 100644 index 75cc6260..00000000 --- a/codes/models/layers/resample2d_package/resample2d_cuda.cc +++ /dev/null @@ -1,32 +0,0 @@ -#include -#include - -#include "resample2d_kernel.cuh" - -int resample2d_cuda_forward( - at::Tensor& input1, - at::Tensor& input2, - at::Tensor& output, - int kernel_size, bool bilinear) { - resample2d_kernel_forward(input1, input2, output, kernel_size, bilinear); - return 1; -} - -int resample2d_cuda_backward( - at::Tensor& input1, - at::Tensor& input2, - at::Tensor& gradOutput, - at::Tensor& gradInput1, - at::Tensor& gradInput2, - int kernel_size, bool bilinear) { - resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size, bilinear); - return 1; -} - - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)"); - m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)"); -} - diff --git a/codes/models/layers/resample2d_package/resample2d_kernel.cuh b/codes/models/layers/resample2d_package/resample2d_kernel.cuh deleted file mode 100644 index a2595159..00000000 --- a/codes/models/layers/resample2d_package/resample2d_kernel.cuh +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include - -void resample2d_kernel_forward( - at::Tensor& input1, - at::Tensor& input2, - at::Tensor& output, - int kernel_size, - bool bilinear); - -void resample2d_kernel_backward( - at::Tensor& input1, - at::Tensor& input2, - at::Tensor& gradOutput, - at::Tensor& gradInput1, - at::Tensor& gradInput2, - int kernel_size, - bool bilinear); \ No newline at end of file diff --git a/codes/models/layers/resample2d_package/setup.py b/codes/models/layers/resample2d_package/setup.py deleted file mode 100644 index 
bbedb255..00000000 --- a/codes/models/layers/resample2d_package/setup.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python3 -import os -import torch - -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -cxx_args = ['-std=c++11'] - -nvcc_args = [ - '-gencode', 'arch=compute_50,code=sm_50', - '-gencode', 'arch=compute_52,code=sm_52', - '-gencode', 'arch=compute_60,code=sm_60', - '-gencode', 'arch=compute_61,code=sm_61', - '-gencode', 'arch=compute_70,code=sm_70', - '-gencode', 'arch=compute_70,code=compute_70' -] - -setup( - name='resample2d_cuda', - ext_modules=[ - CUDAExtension('resample2d_cuda', [ - 'resample2d_cuda.cc', - 'resample2d_kernel.cu' - ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) - ], - cmdclass={ - 'build_ext': BuildExtension - }) diff --git a/codes/models/loss.py b/codes/models/loss.py index 6b8d641d..ef387144 100644 --- a/codes/models/loss.py +++ b/codes/models/loss.py @@ -49,7 +49,7 @@ class GANLoss(nn.Module): # Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL -# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py +# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see scripts/compute_fdpl_perceptual_weights.py # In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs). class FDPLLoss(nn.Module): """ diff --git a/codes/models/novograd.py b/codes/models/novograd.py deleted file mode 100644 index 374479ad..00000000 --- a/codes/models/novograd.py +++ /dev/null @@ -1,71 +0,0 @@ -# Author Masashi Kimura (Convergence Lab.) -import torch -from torch import optim -import math - -class NovoGrad(optim.Optimizer): - def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super(NovoGrad, self).__init__(params, defaults) - self._lr = lr - self._beta1 = betas[0] - self._beta2 = betas[1] - self._eps = eps - self._wd = weight_decay - self._grad_averaging = grad_averaging - - self._momentum_initialized = False - - def step(self, closure=None): - loss = None - if closure is not None: - loss = closure() - - if not self._momentum_initialized: - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - state = self.state[p] - grad = p.grad.data - if grad.is_sparse: - raise RuntimeError('NovoGrad does not support sparse gradients') - - v = torch.norm(grad)**2 - m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data - state['step'] = 0 - state['v'] = v - state['m'] = m - state['grad_ema'] = None - self._momentum_initialized = True - - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - state = self.state[p] - state['step'] += 1 - - step, v, m = state['step'], state['v'], state['m'] - grad_ema = state['grad_ema'] - - grad = p.grad.data - g2 = torch.norm(grad)**2 - grad_ema = g2 if grad_ema is None else grad_ema * \ - self._beta2 + g2*(1. - self._beta2) - grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) - - if self._grad_averaging: - grad *= (1. - self._beta1) - - g2 = torch.norm(grad)**2 - v = self._beta2*v + (1. 
- self._beta2)*g2 - m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd*p.data) - bias_correction1 = 1 - self._beta1 ** step - bias_correction2 = 1 - self._beta2 ** step - step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 - - state['v'], state['m'] = v, m - state['grad_ema'] = grad_ema - p.data.add_(-step_size, m) - return loss \ No newline at end of file diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 9e3dd430..6f4642ab 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -1,8 +1,6 @@ import torch.nn from models.archs.SPSR_arch import ImageGradientNoPadding -from data.weight_scheduler import get_scheduler_for_opt -from utils.util import checkpoint -import torchvision.utils as utils +from utils.weight_scheduler import get_scheduler_for_opt #from models.steps.recursive_gen_injectors import ImageFlowInjector from models.steps.losses import extract_params_from_state diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 6cf21b09..e4b9bfb3 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -6,7 +6,6 @@ import torch from apex import amp from collections import OrderedDict from .injectors import create_injector -from models.novograd import NovoGrad from utils.util import recursively_detach logger = logging.getLogger('base') diff --git a/codes/options/__init__.py b/codes/options/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/options/test/test_ESRGAN.yml b/codes/options/test/test_ESRGAN.yml deleted file mode 100644 index 3522f217..00000000 --- a/codes/options/test/test_ESRGAN.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: RRDB_ESRGAN_x4 -suffix: ~ # add suffix to saved images -model: sr -distortion: sr -scale: 4 -crop_border: ~ # crop border when evaluating. If None(~), crop the scale pixels -gpu_ids: [0] - -datasets: - test_1: # the 1st test dataset - name: set5 - mode: LQGT - dataroot_GT: ../datasets/val_set5/Set5 - dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 - test_2: # the 2nd test dataset - name: set14 - mode: LQGT - dataroot_GT: ../datasets/val_set14/Set14 - dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 - -#### network structures -network_G: - which_model_G: RRDBNet - in_nc: 3 - out_nc: 3 - nf: 64 - nb: 23 - upscale: 4 - -#### path -path: - pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth diff --git a/codes/options/test/test_ESRGAN_woGT.yml b/codes/options/test/test_ESRGAN_woGT.yml deleted file mode 100644 index 24ab5b62..00000000 --- a/codes/options/test/test_ESRGAN_woGT.yml +++ /dev/null @@ -1,26 +0,0 @@ -name: RRDB_ESRGAN_x4 -suffix: ~ # add suffix to saved images -model: sr -distortion: sr -scale: 4 -crop_border: ~ # crop border when evaluating. If None(~), crop the scale pixels -gpu_ids: [0] - -datasets: - test_1: # the 1st test dataset - name: set14 - mode: LQ - dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 - -#### network structures -network_G: - which_model_G: RRDBNet - in_nc: 3 - out_nc: 3 - nf: 64 - nb: 23 - upscale: 4 - -#### path -path: - pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth diff --git a/codes/options/test/test_SRGAN.yml b/codes/options/test/test_SRGAN.yml deleted file mode 100644 index 21eea625..00000000 --- a/codes/options/test/test_SRGAN.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: MSRGANx4 -suffix: ~ # add suffix to saved images -model: sr -distortion: sr -scale: 4 -crop_border: ~ # crop border when evaluating.
If None(~), crop the scale pixels -gpu_ids: [0] - -datasets: - test_1: # the 1st test dataset - name: set5 - mode: LQGT - dataroot_GT: ../datasets/val_set5/Set5 - dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 - test_2: # the 2nd test dataset - name: set14 - mode: LQGT - dataroot_GT: ../datasets/val_set14/Set14 - dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 - -#### network structures -network_G: - which_model_G: MSRResNet - in_nc: 3 - out_nc: 3 - nf: 64 - nb: 16 - upscale: 4 - -#### path -path: - pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth diff --git a/codes/options/test/test_SRResNet.yml b/codes/options/test/test_SRResNet.yml deleted file mode 100644 index b30b3b44..00000000 --- a/codes/options/test/test_SRResNet.yml +++ /dev/null @@ -1,48 +0,0 @@ -name: MSRResNetx4 -suffix: ~ # add suffix to saved images -model: sr -distortion: sr -scale: 4 -crop_border: ~ # crop border when evaluating. If None(~), crop the scale pixels -gpu_ids: [0] - -datasets: - test_1: # the 1st test dataset - name: set5 - mode: LQGT - dataroot_GT: ../datasets/val_set5/Set5 - dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 - test_2: # the 2nd test dataset - name: set14 - mode: LQGT - dataroot_GT: ../datasets/val_set14/Set14 - dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 - test_3: - name: bsd100 - mode: LQGT - dataroot_GT: ../datasets/BSD/BSDS100 - dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4 - test_4: - name: urban100 - mode: LQGT - dataroot_GT: ../datasets/urban100 - dataroot_LQ: ../datasets/urban100_bicLRx4 - test_5: - name: div2k100 - mode: LQGT - dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR - dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4 - - -#### network structures -network_G: - which_model_G: MSRResNet - in_nc: 3 - out_nc: 3 - nf: 64 - nb: 16 - upscale: 4 - -#### path -path: - pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth diff --git a/codes/options/train/train_EDVR_M.yml b/codes/options/train/train_EDVR_M.yml deleted file mode 100644 index 8a1c55df..00000000 --- a/codes/options/train/train_EDVR_M.yml +++ /dev/null @@ -1,80 +0,0 @@ -#### general settings -name: 002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new -use_tb_logger: true -model: video_base -distortion: sr -scale: 4 -gpu_ids: [0,1,2,3,4,5,6,7] - -#### datasets -datasets: - train: - name: REDS - mode: REDS - interval_list: [1] - random_reverse: false - border_mode: false - dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb - dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb - cache_keys: ~ - - N_frames: 5 - use_shuffle: true - n_workers: 3 # per GPU - batch_size: 32 - target_size: 256 - LQ_size: 64 - use_flip: true - use_rot: true - color: RGB - val: - name: REDS4 - mode: video_test - dataroot_GT: ../datasets/REDS4/GT - dataroot_LQ: ../datasets/REDS4/sharp_bicubic - cache_data: True - N_frames: 5 - padding: new_info - -#### network structures -network_G: - which_model_G: EDVR - nf: 64 - nframes: 5 - groups: 8 - front_RBs: 5 - back_RBs: 10 - predeblur: false - HR_in: false - w_TSA: true - -#### path -path: - pretrain_model_G: ../experiments/pretrained_models/EDVR_REDS_SR_M_woTSA.pth - strict_load: false - resume_state: ~ - -#### training settings: learning rate scheme, loss -train: - lr_G: !!float 4e-4 - lr_scheme: CosineAnnealingLR_Restart - beta1: 0.9 - beta2: 0.99 - niter: 600000 - ft_tsa_only: 50000 - warmup_iter: -1 # -1: no warm up - T_period: [50000, 100000, 150000, 150000, 150000] - restarts: [50000, 150000, 300000, 450000] - restart_weights: [1, 1, 1, 1] - eta_min:
!!float 1e-7 - - pixel_criterion: cb - pixel_weight: 1.0 - val_freq: !!float 5e3 - - manual_seed: 0 - -#### logger -logger: - print_freq: 100 - save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_EDVR_woTSA_M.yml b/codes/options/train/train_EDVR_woTSA_M.yml deleted file mode 100644 index cd30f2ab..00000000 --- a/codes/options/train/train_EDVR_woTSA_M.yml +++ /dev/null @@ -1,71 +0,0 @@ -#### general settings -name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S -use_tb_logger: true -model: video_base -distortion: sr -scale: 4 -gpu_ids: [0,1,2,3,4,5,6,7] - -#### datasets -datasets: - train: - name: REDS - mode: REDS - interval_list: [1] - random_reverse: false - border_mode: false - dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb - dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb - cache_keys: ~ - - N_frames: 5 - use_shuffle: true - n_workers: 3 # per GPU - batch_size: 32 - target_size: 256 - LQ_size: 64 - use_flip: true - use_rot: true - color: RGB - -#### network structures -network_G: - which_model_G: EDVR - nf: 64 - nframes: 5 - groups: 8 - front_RBs: 5 - back_RBs: 10 - predeblur: false - HR_in: false - w_TSA: false - -#### path -path: - pretrain_model_G: ~ - strict_load: true - resume_state: ~ - -#### training settings: learning rate scheme, loss -train: - lr_G: !!float 4e-4 - lr_scheme: CosineAnnealingLR_Restart - beta1: 0.9 - beta2: 0.99 - niter: 600000 - warmup_iter: -1 # -1: no warm up - T_period: [150000, 150000, 150000, 150000] - restarts: [150000, 300000, 450000] - restart_weights: [1, 1, 1] - eta_min: !!float 1e-7 - - pixel_criterion: cb - pixel_weight: 1.0 - val_freq: !!float 5e3 - - manual_seed: 0 - -#### logger -logger: - print_freq: 100 - save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_ESRGAN.yml b/codes/options/train/train_ESRGAN.yml deleted file mode 100644 index 8a825b2b..00000000 --- a/codes/options/train/train_ESRGAN.yml +++ /dev/null @@ -1,82 +0,0 @@ -#### general settings -name: 003_RRDB_ESRGANx4_DIV2K -use_tb_logger: true -model: srgan -distortion: sr -scale: 4 -gpu_ids: [0] -amp_opt_level: O1 - -#### datasets -datasets: - train: - name: DIV2K - mode: LQGT - dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub - dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4 - - use_shuffle: true - n_workers: 16 # per GPU - batch_size: 16 - target_size: 128 - use_flip: true - use_rot: true - color: RGB - val: - name: div2kval - mode: LQGT - dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr - dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic - -#### network structures -network_G: - which_model_G: ResGen - nf: 256 -network_D: - which_model_D: discriminator_resnet_passthrough - nf: 42 - -#### path -path: - pretrain_model_G: ~ - strict_load: true - resume_state: ~ - -#### training settings: learning rate scheme, loss -train: - lr_G: !!float 1e-4 - weight_decay_G: 0 - beta1_G: 0.9 - beta2_G: 0.99 - lr_D: !!float 1e-4 - weight_decay_D: 0 - beta1_D: 0.9 - beta2_D: 0.99 - lr_scheme: MultiStepLR - - niter: 400000 - warmup_iter: -1 # no warm up - lr_steps: [50000, 100000, 200000, 300000] - lr_gamma: 0.5 - mega_batch_factor: 1 - - pixel_criterion: l1 - pixel_weight: !!float 1e-2 - feature_criterion: l1 - feature_weight: 1 - feature_weight_decay: .98 - feature_weight_decay_steps: 500 - feature_weight_minimum: .1 - gan_type: gan # gan | ragan - gan_weight: !!float 5e-3 - - D_update_ratio: 2 - D_init_iters: 0 - - manual_seed: 10 - val_freq: !!float 5e2 - -#### logger -logger: - print_freq: 50 - save_checkpoint_freq: 
!!float 5e2 diff --git a/codes/options/train/train_ESRGAN_res.yml b/codes/options/train/train_ESRGAN_res.yml deleted file mode 100644 index 0e42b883..00000000 --- a/codes/options/train/train_ESRGAN_res.yml +++ /dev/null @@ -1,85 +0,0 @@ -#### general settings -name: esrgan_res -use_tb_logger: true -model: srgan -distortion: sr -scale: 4 -gpu_ids: [0] -amp_opt_level: O1 - -#### datasets -datasets: - train: - name: DIV2K - mode: LQGT - dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub - dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4 - - use_shuffle: true - n_workers: 0 # per GPU - batch_size: 24 - target_size: 128 - use_flip: true - use_rot: true - color: RGB - val: - name: div2kval - mode: LQGT - dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr - dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic - -#### network structures -network_G: - which_model_G: ResGen - nf: 256 - nb_denoiser: 2 - nb_upsampler: 28 -network_D: - which_model_D: discriminator_resnet_passthrough - nf: 42 - -#### path -path: - #pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth - #pretrain_model_D: ~ - strict_load: true - resume_state: ../experiments/esrgan_res/training_state/15500.state - -#### training settings: learning rate scheme, loss -train: - lr_G: !!float 1e-4 - weight_decay_G: 0 - beta1_G: 0.9 - beta2_G: 0.99 - lr_D: !!float 1e-4 - weight_decay_D: 0 - beta1_D: 0.9 - beta2_D: 0.99 - lr_scheme: MultiStepLR - - niter: 400000 - warmup_iter: -1 # no warm up - lr_steps: [20000, 40000, 50000, 60000] - lr_gamma: 0.5 - mega_batch_factor: 2 - - pixel_criterion: l1 - pixel_weight: !!float 1e-2 - feature_criterion: l1 - feature_weight: 1 - feature_weight_decay: 1 - feature_weight_decay_steps: 500 - feature_weight_minimum: 1 - gan_type: gan # gan | ragan - gan_weight: !!float 1e-2 - - D_update_ratio: 2 - D_init_iters: -1 - - manual_seed: 10 - val_freq: !!float 5e2 - -#### logger -logger: - print_freq: 50 - save_checkpoint_freq: !!float 5e2 diff --git a/codes/options/train/train_SRGAN.yml b/codes/options/train/train_SRGAN.yml deleted file mode 100644 index 6835601c..00000000 --- a/codes/options/train/train_SRGAN.yml +++ /dev/null @@ -1,85 +0,0 @@ -# Not exactly the same as SRGAN in -# With 16 Residual blocks w/o BN - -#### general settings -name: 002_SRGANx4_MSRResNetx4Ini_DIV2K -use_tb_logger: true -model: srgan -distortion: sr -scale: 4 -gpu_ids: [1] - -#### datasets -datasets: - train: - name: DIV2K - mode: LQGT - dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb - dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb - - use_shuffle: true - n_workers: 6 # per GPU - batch_size: 16 - target_size: 128 - use_flip: true - use_rot: true - color: RGB - val: - name: val_set14 - mode: LQGT - dataroot_GT: ../datasets/val_set14/Set14 - dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 - -#### network structures -network_G: - which_model_G: MSRResNet - in_nc: 3 - out_nc: 3 - nf: 64 - nb: 16 - upscale: 4 -network_D: - which_model_D: discriminator_vgg_128 - in_nc: 3 - nf: 64 - -#### path -path: - pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth - strict_load: true - resume_state: ~ - -#### training settings: learning rate scheme, loss -train: - lr_G: !!float 1e-4 - weight_decay_G: 0 - beta1_G: 0.9 - beta2_G: 0.99 - lr_D: !!float 1e-4 - weight_decay_D: 0 - beta1_D: 0.9 - beta2_D: 0.99 - lr_scheme: MultiStepLR - - niter: 400000 - warmup_iter: -1 # no warm up - lr_steps: [50000, 100000, 200000, 300000] - lr_gamma: 0.5 - - pixel_criterion: l1 - pixel_weight: !!float 1e-2 
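For orientation amid these configs: `lr_scheme: MultiStepLR` with `lr_steps` and `lr_gamma` multiplies the learning rate by `lr_gamma` each time training passes a milestone. A back-of-the-envelope sketch of the effective rate at a given step (illustrative only; the repo's actual schedulers, including the restart variants used above, live in models/lr_scheduler.py):

```python
import bisect

def lr_at_step(step, base_lr=1e-4, lr_steps=(50000, 100000, 200000, 300000), lr_gamma=0.5):
    # base_lr decayed once per milestone already passed.
    return base_lr * lr_gamma ** bisect.bisect_right(lr_steps, step)

assert lr_at_step(120000) == 1e-4 * 0.5 ** 2  # two milestones passed -> 2.5e-5
```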
- feature_criterion: l1 - feature_weight: 1 - gan_type: gan # gan | ragan - gan_weight: !!float 5e-3 - - D_update_ratio: 1 - D_init_iters: 0 - - manual_seed: 10 - val_freq: !!float 5e3 - -#### logger -logger: - print_freq: 100 - save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_SRResNet.yml b/codes/options/train/train_SRResNet.yml deleted file mode 100644 index 15468dce..00000000 --- a/codes/options/train/train_SRResNet.yml +++ /dev/null @@ -1,70 +0,0 @@ -# Not exactly the same as SRResNet in -# With 16 Residual blocks w/o BN - -#### general settings -name: 001_MSRResNetx4_scratch_DIV2K -use_tb_logger: true -model: sr -distortion: sr -scale: 4 -gpu_ids: [0] - -#### datasets -datasets: - train: - name: DIV2K - mode: LQGT - dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb - dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb - - use_shuffle: true - n_workers: 6 # per GPU - batch_size: 16 - target_size: 128 - use_flip: true - use_rot: true - color: RGB - val: - name: val_set5 - mode: LQGT - dataroot_GT: ../datasets/val_set5/Set5 - dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 - -#### network structures -network_G: - which_model_G: MSRResNet - in_nc: 3 - out_nc: 3 - nf: 64 - nb: 16 - upscale: 4 - -#### path -path: - pretrain_model_G: ~ - strict_load: true - resume_state: ~ - -#### training settings: learning rate scheme, loss -train: - lr_G: !!float 2e-4 - lr_scheme: CosineAnnealingLR_Restart - beta1: 0.9 - beta2: 0.99 - niter: 1000000 - warmup_iter: -1 # no warm up - T_period: [250000, 250000, 250000, 250000] - restarts: [250000, 500000, 750000] - restart_weights: [1, 1, 1] - eta_min: !!float 1e-7 - - pixel_criterion: l1 - pixel_weight: 1.0 - - manual_seed: 10 - val_freq: !!float 5e3 - -#### logger -logger: - print_freq: 100 - save_checkpoint_freq: !!float 5e3 diff --git a/codes/process_video.py b/codes/process_video.py index 96ec23bf..6f2279e9 100644 --- a/codes/process_video.py +++ b/codes/process_video.py @@ -11,7 +11,7 @@ import torchvision.transforms.functional as F from PIL import Image from tqdm import tqdm -import options.options as option +from utils import options as option import utils.util as util from data import create_dataloader from models import create_model diff --git a/codes/run_scripts.sh b/codes/run_scripts.sh deleted file mode 100644 index 3e7c4945..00000000 --- a/codes/run_scripts.sh +++ /dev/null @@ -1,10 +0,0 @@ -# single GPU training (image SR) -python train.py -opt options/train/train_SRResNet.yml -python train.py -opt options/train/train_SRGAN.yml -python train.py -opt options/train/train_ESRGAN.yml - - -# distributed training (video SR) -# 8 GPUs -python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_woTSA_M.yml --launcher pytorch -python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_M.yml --launcher pytorch \ No newline at end of file diff --git a/codes/scripts/back_projection/backprojection.m b/codes/scripts/back_projection/backprojection.m deleted file mode 100644 index 496d93f4..00000000 --- a/codes/scripts/back_projection/backprojection.m +++ /dev/null @@ -1,20 +0,0 @@ -function [im_h] = backprojection(im_h, im_l, maxIter) - -[row_l, col_l,~] = size(im_l); -[row_h, col_h,~] = size(im_h); - -p = fspecial('gaussian', 5, 1); -p = p.^2; -p = p./sum(p(:)); - -im_l = double(im_l); -im_h = double(im_h); - -for ii = 1:maxIter - im_l_s = imresize(im_h, [row_l, col_l], 'bicubic'); - im_diff = im_l - im_l_s; - im_diff = 
imresize(im_diff, [row_h, col_h], 'bicubic'); - im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same'); - im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same'); - im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same'); -end diff --git a/codes/scripts/back_projection/main_bp.m b/codes/scripts/back_projection/main_bp.m deleted file mode 100644 index 40c137ed..00000000 --- a/codes/scripts/back_projection/main_bp.m +++ /dev/null @@ -1,22 +0,0 @@ -clear; close all; clc; - -LR_folder = './LR'; % LR -preout_folder = './results'; % pre output -save_folder = './results_20bp'; -filepaths = dir(fullfile(preout_folder, '*.png')); -max_iter = 20; - -if ~ exist(save_folder, 'dir') - mkdir(save_folder); -end - -for idx_im = 1:length(filepaths) - fprintf([num2str(idx_im) '\n']); - im_name = filepaths(idx_im).name; - im_LR = im2double(imread(fullfile(LR_folder, im_name))); - im_out = im2double(imread(fullfile(preout_folder, im_name))); - %tic - im_out = backprojection(im_out, im_LR, max_iter); - %toc - imwrite(im_out, fullfile(save_folder, im_name)); -end diff --git a/codes/scripts/back_projection/main_reverse_filter.m b/codes/scripts/back_projection/main_reverse_filter.m deleted file mode 100644 index 63f2edcf..00000000 --- a/codes/scripts/back_projection/main_reverse_filter.m +++ /dev/null @@ -1,25 +0,0 @@ -clear; close all; clc; - -LR_folder = './LR'; % LR -preout_folder = './results'; % pre output -save_folder = './results_20if'; -filepaths = dir(fullfile(preout_folder, '*.png')); -max_iter = 20; - -if ~ exist(save_folder, 'dir') - mkdir(save_folder); -end - -for idx_im = 1:length(filepaths) - fprintf([num2str(idx_im) '\n']); - im_name = filepaths(idx_im).name; - im_LR = im2double(imread(fullfile(LR_folder, im_name))); - im_out = im2double(imread(fullfile(preout_folder, im_name))); - J = imresize(im_LR,4,'bicubic'); - %tic - for m = 1:max_iter - im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic')); - end - %toc - imwrite(im_out, fullfile(save_folder, im_name)); -end diff --git a/codes/data_scripts/compute_fdpl_perceptual_weights.py b/codes/scripts/compute_fdpl_perceptual_weights.py similarity index 96% rename from codes/data_scripts/compute_fdpl_perceptual_weights.py rename to codes/scripts/compute_fdpl_perceptual_weights.py index 42d4ce2f..1a6506b2 100644 --- a/codes/data_scripts/compute_fdpl_perceptual_weights.py +++ b/codes/scripts/compute_fdpl_perceptual_weights.py @@ -1,14 +1,10 @@ import torch -import os -from PIL import Image import numpy as np -import options.options as option +from utils import options as option from data import create_dataloader, create_dataset import math from tqdm import tqdm -from torchvision import transforms from utils.fdpl_util import dct_2d, extract_patches_2d -import random import matplotlib.pyplot as plt from mpl_toolkits.axes_grid1 import make_axes_locatable from utils.colors import rgb2ycbcr diff --git a/codes/data_scripts/create_lmdb.py b/codes/scripts/create_lmdb.py similarity index 100% rename from codes/data_scripts/create_lmdb.py rename to codes/scripts/create_lmdb.py diff --git a/codes/data_scripts/extract_subimages.py b/codes/scripts/extract_subimages.py similarity index 100% rename from codes/data_scripts/extract_subimages.py rename to codes/scripts/extract_subimages.py diff --git a/codes/data_scripts/extract_subimages_with_ref.py b/codes/scripts/extract_subimages_with_ref.py similarity index 100% rename from codes/data_scripts/extract_subimages_with_ref.py rename to codes/scripts/extract_subimages_with_ref.py diff 
--git a/codes/recover_tensorboard_log.py b/codes/scripts/recover_tensorboard_log.py similarity index 100% rename from codes/recover_tensorboard_log.py rename to codes/scripts/recover_tensorboard_log.py diff --git a/codes/data_scripts/rename.py b/codes/scripts/rename.py similarity index 100% rename from codes/data_scripts/rename.py rename to codes/scripts/rename.py diff --git a/codes/data_scripts/test_dataloader.py b/codes/scripts/test_dataloader.py similarity index 100% rename from codes/data_scripts/test_dataloader.py rename to codes/scripts/test_dataloader.py diff --git a/codes/scripts/transfer_params_MSRResNet.py b/codes/scripts/transfer_params_MSRResNet.py deleted file mode 100644 index 70dafa4d..00000000 --- a/codes/scripts/transfer_params_MSRResNet.py +++ /dev/null @@ -1,27 +0,0 @@ -import os.path as osp -import sys -import torch -try: - sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) - import models.archs.SRResNet_arch as SRResNet_arch -except ImportError: - pass - -pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth') -crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3) -crt_net = crt_model.state_dict() - -for k, v in crt_net.items(): - if k in pretrained_net and 'upconv1' not in k: - crt_net[k] = pretrained_net[k] - print('replace ... ', k) - -# x4 -> x3 -crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2 -crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2 -crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2 -crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2 -crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2 -crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2 - -torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth') diff --git a/codes/data_scripts/use_discriminator_as_filter.py b/codes/scripts/use_discriminator_as_filter.py similarity index 87% rename from codes/data_scripts/use_discriminator_as_filter.py rename to codes/scripts/use_discriminator_as_filter.py index 66443772..94a69eb3 100644 --- a/codes/data_scripts/use_discriminator_as_filter.py +++ b/codes/scripts/use_discriminator_as_filter.py @@ -2,21 +2,14 @@ import os.path as osp import logging import time import argparse -from collections import OrderedDict import os -import options.options as option +from utils import options as option import utils.util as util -from data.util import bgr2ycbcr -import models.archs.SwitchedResidualGenerator_arch as srg -from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb -from switched_conv.switched_conv import compute_attention_specificity from data import create_dataset, create_dataloader from models import create_model from tqdm import tqdm import torch -import models.networks as networks -import shutil import torchvision diff --git a/codes/data_scripts/validate_data.py b/codes/scripts/validate_data.py similarity index 96% rename from codes/data_scripts/validate_data.py rename to codes/scripts/validate_data.py index 9789d7bc..e96084e0 100644 --- a/codes/data_scripts/validate_data.py +++ b/codes/scripts/validate_data.py @@ -5,10 +5,8 @@ import math import argparse import random import torch -import options.options as option -from utils import util +from utils import util, options as option from data import create_dataloader, create_dataset -from time import time from tqdm import 
tqdm from skimage import io diff --git a/codes/temp/cleanup.sh b/codes/temp/cleanup.sh deleted file mode 100644 index 15f2028e..00000000 --- a/codes/temp/cleanup.sh +++ /dev/null @@ -1,9 +0,0 @@ -rm gen/* -rm hr/* -rm lr/* -rm pix/* -rm ref/* -rm genlr/* -rm genmr/* -rm lr_precorrupt/* -rm ref/* \ No newline at end of file diff --git a/codes/test.py b/codes/test.py index ffd8fa7e..40231805 100644 --- a/codes/test.py +++ b/codes/test.py @@ -1,21 +1,5 @@ -import os.path as osp -import logging -import time -import argparse -from collections import OrderedDict - -import os -import options.options as option -import utils.util as util -from data.util import bgr2ycbcr -import models.archs.SwitchedResidualGenerator_arch as srg -from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb -from switched_conv.switched_conv import compute_attention_specificity -from data import create_dataset, create_dataloader -from models import create_model -from tqdm import tqdm import torch -import models.networks as networks + class CheckpointFunction(torch.autograd.Function): @staticmethod @@ -39,7 +23,7 @@ class CheckpointFunction(torch.autograd.Function): input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True) return (None, None) + input_grads -from models.archs.arch_util import ConvGnSilu, UpconvBlock +from models.archs.arch_util import ConvGnSilu import torch.nn as nn if __name__ == "__main__": model = nn.Sequential(ConvGnSilu(3, 64, 3, norm=False), diff --git a/codes/train.py b/codes/train.py index 4568276b..7cf979d9 100644 --- a/codes/train.py +++ b/codes/train.py @@ -3,16 +3,14 @@ import math import argparse import random import logging -import shutil from tqdm import tqdm import torch from data.data_sampler import DistIterSampler -import options.options as option -from utils import util +from utils import util, options as option from data import create_dataloader, create_dataset -from models import create_model +from models.ExtensibleTrainer import ExtensibleTrainer from time import time @@ -159,7 +157,7 @@ def main(): assert train_loader is not None #### create model - model = create_model(opt) + model = ExtensibleTrainer(opt) #### resume training if resume_state: diff --git a/codes/train2.py b/codes/train2.py index 757ec04f..bccb5119 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -3,16 +3,14 @@ import math import argparse import random import logging -import shutil from tqdm import tqdm import torch from data.data_sampler import DistIterSampler -import options.options as option -from utils import util +from models.ExtensibleTrainer import ExtensibleTrainer +from utils import util, options as option from data import create_dataloader, create_dataset -from models import create_model from time import time @@ -159,7 +157,7 @@ def main(): assert train_loader is not None #### create model - model = create_model(opt) + model = ExtensibleTrainer(opt) #### resume training if resume_state: diff --git a/codes/utils/convert_model.py b/codes/utils/convert_model.py index 2f98ba8a..fa0c0caf 100644 --- a/codes/utils/convert_model.py +++ b/codes/utils/convert_model.py @@ -1,7 +1,7 @@ # Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive" # models which can be trained piecemeal. 
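One note on the `CheckpointFunction` that test.py keeps (shown in the hunk above): it is a hand-rolled form of activation checkpointing, running the wrapped forward under no_grad and re-running it inside `backward` to recompute activations, trading compute for memory. Stock PyTorch ships the same mechanism in `torch.utils.checkpoint`; a usage sketch under that assumption (the module and shapes here are illustrative):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

net = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
                    nn.Conv2d(64, 3, 3, padding=1))
x = torch.randn(1, 3, 32, 32, requires_grad=True)

# Activations inside `net` are not stored during the forward pass;
# they are recomputed when backward() reaches this node.
y = checkpoint(net, x)
y.sum().backward()
```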
-import options.options as option +from utils import options as option from models import create_model import torch import os diff --git a/codes/utils/distill_torchscript.py b/codes/utils/distill_torchscript.py index 0f3aa173..0b2baec0 100644 --- a/codes/utils/distill_torchscript.py +++ b/codes/utils/distill_torchscript.py @@ -1,7 +1,7 @@ import argparse import functools import torch -import options.options as option +from utils import options as option from models.networks import define_G diff --git a/codes/options/options.py b/codes/utils/options.py similarity index 100% rename from codes/options/options.py rename to codes/utils/options.py diff --git a/codes/data/weight_scheduler.py b/codes/utils/weight_scheduler.py similarity index 100% rename from codes/data/weight_scheduler.py rename to codes/utils/weight_scheduler.py