More refactoring

This commit is contained in:
James Betker 2020-12-18 09:18:34 -07:00
parent b905b108da
commit 5640e4efe4
69 changed files with 62 additions and 506 deletions

8
.gitmodules vendored
View File

@ -1,13 +1,9 @@
[submodule "flownet2"] [submodule "flownet2"]
path = flownet2 path = flownet2
url = https://github.com/NVIDIA/flownet2-pytorch.git url = https://github.com/NVIDIA/flownet2-pytorch.git
[submodule "codes/switched_conv"] [submodule "codes/models/switched_conv"]
path = codes/switched_conv path = codes/models/switched_conv
url = https://github.com/neonbjb/SwitchedConvolutions.git url = https://github.com/neonbjb/SwitchedConvolutions.git
[submodule "codes/models/flownet2"] [submodule "codes/models/flownet2"]
path = codes/models/flownet2 path = codes/models/flownet2
url = https://github.com/neonbjb/flownet2-pytorch.git url = https://github.com/neonbjb/flownet2-pytorch.git
branch = master
[submodule "codes/models/archs/flownet2"]
path = codes/models/archs/flownet2
url = https://github.com/neonbjb/flownet2-pytorch.git

View File

@ -1,261 +0,0 @@
function calculate_PSNR_SSIM()
% GT and SR folder
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5';
folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5';
scale = 4;
suffix = ''; % suffix for SR images
test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels
if test_Y
fprintf('Tesing Y channel.\n');
else
fprintf('Tesing RGB channels.\n');
end
filepaths = dir(fullfile(folder_GT, '*.png'));
PSNR_all = zeros(1, length(filepaths));
SSIM_all = zeros(1, length(filepaths));
for idx_im = 1:length(filepaths)
im_name = filepaths(idx_im).name;
im_GT = imread(fullfile(folder_GT, im_name));
im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png']));
if test_Y % evaluate on Y channel in YCbCr color space
if size(im_GT, 3) == 3
im_GT_YCbCr = rgb2ycbcr(im2double(im_GT));
im_GT_in = im_GT_YCbCr(:,:,1);
im_SR_YCbCr = rgb2ycbcr(im2double(im_SR));
im_SR_in = im_SR_YCbCr(:,:,1);
else
im_GT_in = im2double(im_GT);
im_SR_in = im2double(im_SR);
end
else % evaluate on RGB channels
im_GT_in = im2double(im_GT);
im_SR_in = im2double(im_SR);
end
% calculate PSNR and SSIM
PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale);
SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale);
fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im));
end
fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all));
end
function res = calculate_PSNR(GT, SR, border)
% remove border
GT = GT(border+1:end-border, border+1:end-border, :);
SR = SR(border+1:end-border, border+1:end-border, :);
% calculate PNSR (assume in [0,255])
error = GT(:) - SR(:);
mse = mean(error.^2);
res = 10 * log10(255^2/mse);
end
function res = calculate_SSIM(GT, SR, border)
GT = GT(border+1:end-border, border+1:end-border, :);
SR = SR(border+1:end-border, border+1:end-border, :);
% calculate SSIM
mssim = zeros(1, size(SR, 3));
for i = 1:size(SR,3)
[mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i));
end
res = mean(mssim);
end
function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
%========================================================================
%SSIM Index, Version 1.0
%Copyright(c) 2003 Zhou Wang
%All Rights Reserved.
%
%The author is with Howard Hughes Medical Institute, and Laboratory
%for Computational Vision at Center for Neural Science and Courant
%Institute of Mathematical Sciences, New York University.
%
%----------------------------------------------------------------------
%Permission to use, copy, or modify this software and its documentation
%for educational and research purposes only and without fee is hereby
%granted, provided that this copyright notice and the original authors'
%names appear on all copies and supporting documentation. This program
%shall not be used, rewritten, or adapted as the basis of a commercial
%software or hardware product without first obtaining permission of the
%authors. The authors make no representations about the suitability of
%this software for any purpose. It is provided "as is" without express
%or implied warranty.
%----------------------------------------------------------------------
%
%This is an implementation of the algorithm for calculating the
%Structural SIMilarity (SSIM) index between two images. Please refer
%to the following paper:
%
%Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
%quality assessment: From error measurement to structural similarity"
%IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004.
%
%Kindly report any suggestions or corrections to zhouwang@ieee.org
%
%----------------------------------------------------------------------
%
%Input : (1) img1: the first image being compared
% (2) img2: the second image being compared
% (3) K: constants in the SSIM index formula (see the above
% reference). defualt value: K = [0.01 0.03]
% (4) window: local window for statistics (see the above
% reference). default widnow is Gaussian given by
% window = fspecial('gaussian', 11, 1.5);
% (5) L: dynamic range of the images. default: L = 255
%
%Output: (1) mssim: the mean SSIM index value between 2 images.
% If one of the images being compared is regarded as
% perfect quality, then mssim can be considered as the
% quality measure of the other image.
% If img1 = img2, then mssim = 1.
% (2) ssim_map: the SSIM index map of the test image. The map
% has a smaller size than the input images. The actual size:
% size(img1) - size(window) + 1.
%
%Default Usage:
% Given 2 test images img1 and img2, whose dynamic range is 0-255
%
% [mssim ssim_map] = ssim_index(img1, img2);
%
%Advanced Usage:
% User defined parameters. For example
%
% K = [0.05 0.05];
% window = ones(8);
% L = 100;
% [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
%
%See the results:
%
% mssim %Gives the mssim value
% imshow(max(0, ssim_map).^4) %Shows the SSIM index map
%
%========================================================================
if (nargin < 2 || nargin > 5)
ssim_index = -Inf;
ssim_map = -Inf;
return;
end
if (size(img1) ~= size(img2))
ssim_index = -Inf;
ssim_map = -Inf;
return;
end
[M, N] = size(img1);
if (nargin == 2)
if ((M < 11) || (N < 11))
ssim_index = -Inf;
ssim_map = -Inf;
return
end
window = fspecial('gaussian', 11, 1.5); %
K(1) = 0.01; % default settings
K(2) = 0.03; %
L = 255; %
end
if (nargin == 3)
if ((M < 11) || (N < 11))
ssim_index = -Inf;
ssim_map = -Inf;
return
end
window = fspecial('gaussian', 11, 1.5);
L = 255;
if (length(K) == 2)
if (K(1) < 0 || K(2) < 0)
ssim_index = -Inf;
ssim_map = -Inf;
return;
end
else
ssim_index = -Inf;
ssim_map = -Inf;
return;
end
end
if (nargin == 4)
[H, W] = size(window);
if ((H*W) < 4 || (H > M) || (W > N))
ssim_index = -Inf;
ssim_map = -Inf;
return
end
L = 255;
if (length(K) == 2)
if (K(1) < 0 || K(2) < 0)
ssim_index = -Inf;
ssim_map = -Inf;
return;
end
else
ssim_index = -Inf;
ssim_map = -Inf;
return;
end
end
if (nargin == 5)
[H, W] = size(window);
if ((H*W) < 4 || (H > M) || (W > N))
ssim_index = -Inf;
ssim_map = -Inf;
return
end
if (length(K) == 2)
if (K(1) < 0 || K(2) < 0)
ssim_index = -Inf;
ssim_map = -Inf;
return;
end
else
ssim_index = -Inf;
ssim_map = -Inf;
return;
end
end
C1 = (K(1)*L)^2;
C2 = (K(2)*L)^2;
window = window/sum(sum(window));
img1 = double(img1);
img2 = double(img2);
mu1 = filter2(window, img1, 'valid');
mu2 = filter2(window, img2, 'valid');
mu1_sq = mu1.*mu1;
mu2_sq = mu2.*mu2;
mu1_mu2 = mu1.*mu2;
sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
if (C1 > 0 && C2 > 0)
ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
else
numerator1 = 2*mu1_mu2 + C1;
numerator2 = 2*sigma12 + C2;
denominator1 = mu1_sq + mu2_sq + C1;
denominator2 = sigma1_sq + sigma2_sq + C2;
ssim_map = ones(size(mu1));
index = (denominator1.*denominator2 > 0);
ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
index = (denominator1 ~= 0) & (denominator2 == 0);
ssim_map(index) = numerator1(index)./denominator1(index);
end
mssim = mean2(ssim_map);
end

View File

@ -1,147 +0,0 @@
'''
calculate the PSNR and SSIM.
same as MATLAB's results
'''
import os
import math
import numpy as np
import cv2
import glob
def main():
# Configurations
# GT - Ground-truth;
# Gen: Generated / Restored / Recovered images
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'
folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'
crop_border = 4
suffix = '' # suffix for Gen images
test_Y = False # True: test Y channel only; False: test RGB channels
PSNR_all = []
SSIM_all = []
img_list = sorted(glob.glob(folder_GT + '/*'))
if test_Y:
print('Testing Y channel.')
else:
print('Testing RGB channels.')
for i, img_path in enumerate(img_list):
base_name = os.path.splitext(os.path.basename(img_path))[0]
im_GT = cv2.imread(img_path) / 255.
im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255.
if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space
im_GT_in = bgr2ycbcr(im_GT)
im_Gen_in = bgr2ycbcr(im_Gen)
else:
im_GT_in = im_GT
im_Gen_in = im_Gen
# crop borders
if im_GT_in.ndim == 3:
cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :]
cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :]
elif im_GT_in.ndim == 2:
cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border]
cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border]
else:
raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim))
# calculate PSNR and SSIM
PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255)
SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255)
print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format(
i + 1, base_name, PSNR, SSIM))
PSNR_all.append(PSNR)
SSIM_all.append(SSIM)
print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format(
sum(PSNR_all) / len(PSNR_all),
sum(SSIM_all) / len(SSIM_all)))
def calculate_psnr(img1, img2):
# img1 and img2 have range [0, 255]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2)**2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
def ssim(img1, img2):
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
def calculate_ssim(img1, img2):
'''calculate SSIM
the same outputs as MATLAB's
img1, img2: [0, 255]
'''
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
if img1.ndim == 2:
return ssim(img1, img2)
elif img1.ndim == 3:
if img1.shape[2] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(img1, img2))
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError('Wrong input image dimensions.')
def bgr2ycbcr(img, only_y=True):
'''same as matlab rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
'''
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
# convert
if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else:
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
return rlt.astype(in_img_type)
if __name__ == '__main__':
main()

@ -1 +0,0 @@
Subproject commit db2b7899ea8506e90418dbd389300c49bdbb55c3

View File

@ -1,7 +1,7 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from models.archs.srflow_orig import thops from models.srflow_orig import thops
class _ActNorm(nn.Module): class _ActNorm(nn.Module):

View File

@ -1,7 +1,7 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from models.archs.srflow_orig import thops from models.srflow_orig import thops
from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
from utils.util import opt_get from utils.util import opt_get

View File

@ -2,8 +2,6 @@ import torch
from torch import nn as nn from torch import nn as nn
import models.archs.srflow_orig.Permutations import models.archs.srflow_orig.Permutations
from models.archs.srflow_orig import flow, thops, FlowAffineCouplingsAblation
from utils.util import opt_get
def getConditional(rrdbResults, position): def getConditional(rrdbResults, position):

View File

@ -3,7 +3,8 @@ import torch
from torch import nn as nn from torch import nn as nn
import models.archs.srflow_orig.Split import models.archs.srflow_orig.Split
from models.archs.srflow_orig import flow, thops from models.archs.srflow_orig import flow
from models.srflow_orig import thops
from models.archs.srflow_orig.Split import Split2d from models.archs.srflow_orig.Split import Split2d
from models.archs.srflow_orig.glow_arch import f_conv2d_bias from models.archs.srflow_orig.glow_arch import f_conv2d_bias
from models.archs.srflow_orig.FlowStep import FlowStep from models.archs.srflow_orig.FlowStep import FlowStep

View File

@ -3,7 +3,7 @@ import torch
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from models.archs.srflow_orig import thops from models.srflow_orig import thops
class InvertibleConv1x1(nn.Module): class InvertibleConv1x1(nn.Module):

View File

@ -3,11 +3,10 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision
import numpy as np import numpy as np
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
import models.archs.srflow_orig.thops as thops import models.srflow_orig.thops as thops
import models.archs.srflow_orig.flow as flow import models.archs.srflow_orig.flow as flow
from utils.util import opt_get from utils.util import opt_get

View File

@ -1,8 +1,7 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from models.archs.srflow_orig import thops from models.srflow_orig import thops
from models.archs.srflow_orig.FlowStep import FlowStep
from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
from utils.util import opt_get from utils.util import opt_get

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
from models.injectors import Injector from trainer.injectors import Injector
from utils.util import checkpoint from utils.util import checkpoint

View File

@ -11,7 +11,7 @@ import torchvision.transforms.functional as F
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from models.ExtensibleTrainer import ExtensibleTrainer from trainer.ExtensibleTrainer import ExtensibleTrainer
from utils import options as option from utils import options as option
import utils.util as util import utils.util as util
from data import create_dataloader from data import create_dataloader

View File

@ -2,7 +2,6 @@ import argparse
import logging import logging
import math import math
import os import os
import random
from glob import glob from glob import glob
import torch import torch
@ -14,9 +13,8 @@ from tqdm import tqdm
import utils.options as option import utils.options as option
import utils import utils
from data import create_dataset, create_dataloader
from data.image_corruptor import ImageCorruptor from data.image_corruptor import ImageCorruptor
from models.ExtensibleTrainer import ExtensibleTrainer from trainer.ExtensibleTrainer import ExtensibleTrainer
from utils import util from utils import util

View File

@ -1,9 +1,7 @@
import os.path as osp import os.path as osp
import logging import logging
import shutil
import time import time
import argparse import argparse
from collections import OrderedDict
import os import os
@ -12,15 +10,10 @@ import torchvision
import utils import utils
import utils.options as option import utils.options as option
import utils.util as util import utils.util as util
from data.util import bgr2ycbcr from trainer.ExtensibleTrainer import ExtensibleTrainer
import models.archs.SwitchedResidualGenerator_arch as srg
from models.ExtensibleTrainer import ExtensibleTrainer
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
from switched_conv.switched_conv import compute_attention_specificity
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
from tqdm import tqdm from tqdm import tqdm
import torch import torch
import models.networks as networks
if __name__ == "__main__": if __name__ == "__main__":
#### options #### options

View File

@ -5,7 +5,7 @@ import utils
import utils.options as option import utils.options as option
import utils.util as util import utils.util as util
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
from models.ExtensibleTrainer import ExtensibleTrainer from trainer.ExtensibleTrainer import ExtensibleTrainer
class PretrainedImagePatchClassifier: class PretrainedImagePatchClassifier:

View File

@ -6,8 +6,8 @@ import argparse
import os import os
import utils import utils
from models.ExtensibleTrainer import ExtensibleTrainer from trainer.ExtensibleTrainer import ExtensibleTrainer
from models.networks import define_F from trainer.networks import define_F
from utils import options as option from utils import options as option
import utils.util as util import utils.util as util
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader

View File

@ -4,20 +4,13 @@ import time
import argparse import argparse
from collections import OrderedDict from collections import OrderedDict
import os
import utils import utils
import utils.options as option import utils.options as option
import utils.util as util import utils.util as util
from data.util import bgr2ycbcr from trainer.ExtensibleTrainer import ExtensibleTrainer
import models.archs.SwitchedResidualGenerator_arch as srg
from models.ExtensibleTrainer import ExtensibleTrainer
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
from switched_conv.switched_conv import compute_attention_specificity
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
from tqdm import tqdm from tqdm import tqdm
import torch import torch
import models.networks as networks
def forward_pass(model, output_dir, alteration_suffix=''): def forward_pass(model, output_dir, alteration_suffix=''):

View File

@ -2,22 +2,16 @@ import os.path as osp
import logging import logging
import time import time
import argparse import argparse
from collections import OrderedDict
import os import os
import utils import utils
import utils.options as option import utils.options as option
import utils.util as util import utils.util as util
from data.util import bgr2ycbcr from trainer.ExtensibleTrainer import ExtensibleTrainer
import models.archs.SwitchedResidualGenerator_arch as srg
from models.ExtensibleTrainer import ExtensibleTrainer
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
from switched_conv.switched_conv import compute_attention_specificity
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
from tqdm import tqdm from tqdm import tqdm
import torch import torch
import models.networks as networks
import torchvision import torchvision

View File

@ -7,11 +7,11 @@ from tqdm import tqdm
import torch import torch
from data.data_sampler import DistIterSampler from data.data_sampler import DistIterSampler
from models.eval import create_evaluator from trainer.eval import create_evaluator
from utils import util, options as option from utils import util, options as option
from data import create_dataloader, create_dataset from data import create_dataloader, create_dataset
from models.ExtensibleTrainer import ExtensibleTrainer from trainer.ExtensibleTrainer import ExtensibleTrainer
from time import time from time import time
def init_dist(backend, **kwargs): def init_dist(backend, **kwargs):

View File

@ -5,12 +5,12 @@ import torch
from torch.nn.parallel import DataParallel from torch.nn.parallel import DataParallel
import torch.nn as nn import torch.nn as nn
import models.lr_scheduler as lr_scheduler import trainer.lr_scheduler as lr_scheduler
import models.networks as networks import trainer.networks as networks
from models.base_model import BaseModel from trainer.base_model import BaseModel
from models.injectors import create_injector from trainer.injectors import create_injector
from models.steps import ConfigurableStep from trainer.steps import ConfigurableStep
from models.experiments.experiments import get_experiment_for_name from trainer.experiments.experiments import get_experiment_for_name
import torchvision.utils as utils import torchvision.utils as utils
logger = logging.getLogger('base') logger = logging.getLogger('base')

View File

View File

@ -6,8 +6,8 @@ import torchvision
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from data.multiscale_dataset import build_multiscale_patch_index_map from data.multiscale_dataset import build_multiscale_patch_index_map
from models.injectors import Injector from trainer.injectors import Injector
from models.losses import extract_params_from_state from trainer.losses import extract_params_from_state
import os.path as osp import os.path as osp
@ -130,7 +130,7 @@ class ProgressiveGeneratorInjector(Injector):
lbl = 'generator_recurrent' lbl = 'generator_recurrent'
else: else:
lbl = 'generator_regular' lbl = 'generator_regular'
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", lbl, str(self.env['step'])) base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", lbl, str(self.env['step']))
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
ind = 1 ind = 1
for i, o in zip(chain_inputs, chain_outputs): for i, o in zip(chain_inputs, chain_outputs):

View File

@ -2,7 +2,7 @@ import torch
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from models.archs.flownet2.networks import Resample2d from models.archs.flownet2.networks import Resample2d
from models.archs.flownet2 import flow2img from models.archs.flownet2 import flow2img
from models.injectors import Injector from trainer.injectors import Injector
def create_stereoscopic_injector(opt, env): def create_stereoscopic_injector(opt, env):

View File

@ -1,9 +1,9 @@
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
from models.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.archs.flownet2.networks import Resample2d from models.archs.flownet2.networks import Resample2d
from models.injectors import Injector from trainer.injectors import Injector
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import os import os
@ -156,7 +156,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it): def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
if self.env['rank'] > 0: if self.env['rank'] > 0:
return return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step'])) base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_geninput", str(self.env['step']))
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,))) torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,)))
torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,))) torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,)))
@ -345,7 +345,7 @@ class TecoGanLoss(ConfigurableLoss):
def produce_teco_visual_debugs(self, sext, lbl, it): def produce_teco_visual_debugs(self, sext, lbl, it):
if self.env['rank'] > 0: if self.env['rank'] > 0:
return return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl) base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c'] lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
for i in range(6): for i in range(6):
@ -378,7 +378,7 @@ class PingPongLoss(ConfigurableLoss):
def produce_teco_visual_debugs(self, imglist): def produce_teco_visual_debugs(self, imglist):
if self.env['rank'] > 0: if self.env['rank'] > 0:
return return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
cnt = imglist.shape[1] cnt = imglist.shape[1]
for i in range(cnt): for i in range(cnt):
@ -388,7 +388,7 @@ class PingPongLoss(ConfigurableLoss):
def produce_teco_visual_debugs2(self, imga, imgb, i): def produce_teco_visual_debugs2(self, imga, imgb, i):
if self.env['rank'] > 0: if self.env['rank'] > 0:
return return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
torchvision.utils.save_image(imga.float(), osp.join(base_path, "%s_a.png" % (i, ))) torchvision.utils.save_image(imga.float(), osp.join(base_path, "%s_a.png" % (i, )))
torchvision.utils.save_image(imgb.float(), osp.join(base_path, "%s_b.png" % (i, ))) torchvision.utils.save_image(imgb.float(), osp.join(base_path, "%s_b.png" % (i, )))

View File

@ -1,6 +1,6 @@
from models.eval.flow_gaussian_nll import FlowGaussianNll from trainer.eval.flow_gaussian_nll import FlowGaussianNll
from models.eval.sr_style import SrStyleTransferEvaluator from trainer.eval.sr_style import SrStyleTransferEvaluator
from models.eval.style import StyleTransferEvaluator from trainer.eval import StyleTransferEvaluator
def create_evaluator(model, opt_eval, env): def create_evaluator(model, opt_eval, env):

View File

@ -1,14 +1,8 @@
import os
import torch import torch
import os.path as osp
import torchvision
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
import models.eval.evaluator as evaluator import trainer.eval.evaluator as evaluator
from pytorch_fid import fid_score
# Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair. # Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
from data.image_folder_dataset import ImageFolderDataset from data.image_folder_dataset import ImageFolderDataset

View File

@ -5,7 +5,7 @@ import os.path as osp
import torchvision import torchvision
from torch.utils.data import BatchSampler from torch.utils.data import BatchSampler
import models.eval.evaluator as evaluator import trainer.eval.evaluator as evaluator
from pytorch_fid import fid_score from pytorch_fid import fid_score
@ -32,9 +32,9 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
def perform_eval(self): def perform_eval(self):
embedding_generator = self.env['generators'][self.embedding_generator] embedding_generator = self.env['generators'][self.embedding_generator]
fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"])) fid_fake_path = osp.join(self.env['base_path'], "../../models", "fid_fake", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True) os.makedirs(fid_fake_path, exist_ok=True)
fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"])) fid_real_path = osp.join(self.env['base_path'], "../../models", "fid_real", str(self.env["step"]))
os.makedirs(fid_real_path, exist_ok=True) os.makedirs(fid_real_path, exist_ok=True)
counter = 0 counter = 0
for batch in self.sampler: for batch in self.sampler:

View File

@ -3,7 +3,7 @@ import os
import torch import torch
import os.path as osp import os.path as osp
import torchvision import torchvision
import models.eval.evaluator as evaluator import trainer.eval.evaluator as evaluator
from pytorch_fid import fid_score from pytorch_fid import fid_score
@ -18,7 +18,7 @@ class StyleTransferEvaluator(evaluator.Evaluator):
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0 self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
def perform_eval(self): def perform_eval(self):
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"])) fid_fake_path = osp.join(self.env['base_path'], "../../models", "fid", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True) os.makedirs(fid_fake_path, exist_ok=True)
counter = 0 counter = 0
for i in range(self.batches_per_eval): for i in range(self.batches_per_eval):

View File

View File

@ -3,8 +3,8 @@ from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
import models.networks as networks import trainer.networks as networks
import models.lr_scheduler as lr_scheduler import trainer.lr_scheduler as lr_scheduler
from .base_model import BaseModel from .base_model import BaseModel
logger = logging.getLogger('base') logger = logging.getLogger('base')

View File

@ -4,19 +4,19 @@ import torch.nn
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from utils.weight_scheduler import get_scheduler_for_opt from utils.weight_scheduler import get_scheduler_for_opt
from models.losses import extract_params_from_state from trainer.losses import extract_params_from_state
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions. # Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
def create_injector(opt_inject, env): def create_injector(opt_inject, env):
type = opt_inject['type'] type = opt_inject['type']
if 'teco_' in type: if 'teco_' in type:
from models.custom_training_components import create_teco_injector from trainer.custom_training_components import create_teco_injector
return create_teco_injector(opt_inject, env) return create_teco_injector(opt_inject, env)
elif 'progressive_' in type: elif 'progressive_' in type:
from models.custom_training_components import create_progressive_zoom_injector from trainer.custom_training_components import create_progressive_zoom_injector
return create_progressive_zoom_injector(opt_inject, env) return create_progressive_zoom_injector(opt_inject, env)
elif 'stereoscopic_' in type: elif 'stereoscopic_' in type:
from models.custom_training_components import create_stereoscopic_injector from trainer.custom_training_components import create_stereoscopic_injector
return create_stereoscopic_injector(opt_inject, env) return create_stereoscopic_injector(opt_inject, env)
elif 'igpt' in type: elif 'igpt' in type:
from models.archs.transformers.igpt import gpt2 from models.archs.transformers.igpt import gpt2

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from models.loss import GANLoss from trainer.loss import GANLoss
import random import random
import functools import functools
import torch.nn.functional as F import torch.nn.functional as F
@ -11,7 +11,7 @@ import torch.nn.functional as F
def create_loss(opt_loss, env): def create_loss(opt_loss, env):
type = opt_loss['type'] type = opt_loss['type']
if 'teco_' in type: if 'teco_' in type:
from models.custom_training_components.tecogan_losses import create_teco_loss from trainer.custom_training_components import create_teco_loss
return create_teco_loss(opt_loss, env) return create_teco_loss(opt_loss, env)
elif 'stylegan2_' in type: elif 'stylegan2_' in type:
from models.archs.stylegan import create_stylegan2_loss from models.archs.stylegan import create_stylegan2_loss
@ -152,9 +152,9 @@ class FeatureLoss(ConfigurableLoss):
super(FeatureLoss, self).__init__(opt, env) super(FeatureLoss, self).__init__(opt, env)
self.opt = opt self.opt = opt
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device']) self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
import models.networks import trainer.networks
self.netF = models.networks.define_F(which_model=opt['which_model_F'], self.netF = trainer.networks.define_F(which_model=opt['which_model_F'],
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device']) load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
if not env['opt']['dist']: if not env['opt']['dist']:
self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids']) self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
@ -178,9 +178,9 @@ class InterpretedFeatureLoss(ConfigurableLoss):
super(InterpretedFeatureLoss, self).__init__(opt, env) super(InterpretedFeatureLoss, self).__init__(opt, env)
self.opt = opt self.opt = opt
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device']) self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
import models.networks import trainer.networks
self.netF_real = models.networks.define_F(which_model=opt['which_model_F']).to(self.env['device']) self.netF_real = trainer.networks.define_F(which_model=opt['which_model_F']).to(self.env['device'])
self.netF_gen = models.networks.define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device']) self.netF_gen = trainer.networks.define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
if not env['opt']['dist']: if not env['opt']['dist']:
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real) self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen) self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)

View File

@ -3,10 +3,10 @@ from torch.cuda.amp import GradScaler
from utils.loss_accumulator import LossAccumulator from utils.loss_accumulator import LossAccumulator
from torch.nn import Module from torch.nn import Module
import logging import logging
from models.losses import create_loss from trainer.losses import create_loss
import torch import torch
from collections import OrderedDict from collections import OrderedDict
from models.injectors import create_injector from trainer.injectors import create_injector
from utils.util import recursively_detach from utils.util import recursively_detach
logger = logging.getLogger('base') logger = logging.getLogger('base')

View File

@ -2,7 +2,7 @@ import argparse
import functools import functools
import torch import torch
from utils import options as option from utils import options as option
from models.networks import define_G from trainer.networks import define_G
class TracedModule: class TracedModule: