forked from mrq/DL-Art-School
More refactoring
This commit is contained in:
parent
b905b108da
commit
5640e4efe4
8
.gitmodules
vendored
8
.gitmodules
vendored
|
@ -1,13 +1,9 @@
|
|||
[submodule "flownet2"]
|
||||
path = flownet2
|
||||
url = https://github.com/NVIDIA/flownet2-pytorch.git
|
||||
[submodule "codes/switched_conv"]
|
||||
path = codes/switched_conv
|
||||
[submodule "codes/models/switched_conv"]
|
||||
path = codes/models/switched_conv
|
||||
url = https://github.com/neonbjb/SwitchedConvolutions.git
|
||||
[submodule "codes/models/flownet2"]
|
||||
path = codes/models/flownet2
|
||||
url = https://github.com/neonbjb/flownet2-pytorch.git
|
||||
branch = master
|
||||
[submodule "codes/models/archs/flownet2"]
|
||||
path = codes/models/archs/flownet2
|
||||
url = https://github.com/neonbjb/flownet2-pytorch.git
|
||||
|
|
|
@ -1,261 +0,0 @@
|
|||
function calculate_PSNR_SSIM()
|
||||
|
||||
% GT and SR folder
|
||||
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5';
|
||||
folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5';
|
||||
scale = 4;
|
||||
suffix = ''; % suffix for SR images
|
||||
test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels
|
||||
if test_Y
|
||||
fprintf('Tesing Y channel.\n');
|
||||
else
|
||||
fprintf('Tesing RGB channels.\n');
|
||||
end
|
||||
filepaths = dir(fullfile(folder_GT, '*.png'));
|
||||
PSNR_all = zeros(1, length(filepaths));
|
||||
SSIM_all = zeros(1, length(filepaths));
|
||||
|
||||
for idx_im = 1:length(filepaths)
|
||||
im_name = filepaths(idx_im).name;
|
||||
im_GT = imread(fullfile(folder_GT, im_name));
|
||||
im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png']));
|
||||
|
||||
if test_Y % evaluate on Y channel in YCbCr color space
|
||||
if size(im_GT, 3) == 3
|
||||
im_GT_YCbCr = rgb2ycbcr(im2double(im_GT));
|
||||
im_GT_in = im_GT_YCbCr(:,:,1);
|
||||
im_SR_YCbCr = rgb2ycbcr(im2double(im_SR));
|
||||
im_SR_in = im_SR_YCbCr(:,:,1);
|
||||
else
|
||||
im_GT_in = im2double(im_GT);
|
||||
im_SR_in = im2double(im_SR);
|
||||
end
|
||||
else % evaluate on RGB channels
|
||||
im_GT_in = im2double(im_GT);
|
||||
im_SR_in = im2double(im_SR);
|
||||
end
|
||||
|
||||
% calculate PSNR and SSIM
|
||||
PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale);
|
||||
SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale);
|
||||
fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im));
|
||||
end
|
||||
|
||||
fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all));
|
||||
end
|
||||
|
||||
function res = calculate_PSNR(GT, SR, border)
|
||||
% remove border
|
||||
GT = GT(border+1:end-border, border+1:end-border, :);
|
||||
SR = SR(border+1:end-border, border+1:end-border, :);
|
||||
% calculate PNSR (assume in [0,255])
|
||||
error = GT(:) - SR(:);
|
||||
mse = mean(error.^2);
|
||||
res = 10 * log10(255^2/mse);
|
||||
end
|
||||
|
||||
function res = calculate_SSIM(GT, SR, border)
|
||||
GT = GT(border+1:end-border, border+1:end-border, :);
|
||||
SR = SR(border+1:end-border, border+1:end-border, :);
|
||||
% calculate SSIM
|
||||
mssim = zeros(1, size(SR, 3));
|
||||
for i = 1:size(SR,3)
|
||||
[mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i));
|
||||
end
|
||||
res = mean(mssim);
|
||||
end
|
||||
|
||||
function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
|
||||
|
||||
%========================================================================
|
||||
%SSIM Index, Version 1.0
|
||||
%Copyright(c) 2003 Zhou Wang
|
||||
%All Rights Reserved.
|
||||
%
|
||||
%The author is with Howard Hughes Medical Institute, and Laboratory
|
||||
%for Computational Vision at Center for Neural Science and Courant
|
||||
%Institute of Mathematical Sciences, New York University.
|
||||
%
|
||||
%----------------------------------------------------------------------
|
||||
%Permission to use, copy, or modify this software and its documentation
|
||||
%for educational and research purposes only and without fee is hereby
|
||||
%granted, provided that this copyright notice and the original authors'
|
||||
%names appear on all copies and supporting documentation. This program
|
||||
%shall not be used, rewritten, or adapted as the basis of a commercial
|
||||
%software or hardware product without first obtaining permission of the
|
||||
%authors. The authors make no representations about the suitability of
|
||||
%this software for any purpose. It is provided "as is" without express
|
||||
%or implied warranty.
|
||||
%----------------------------------------------------------------------
|
||||
%
|
||||
%This is an implementation of the algorithm for calculating the
|
||||
%Structural SIMilarity (SSIM) index between two images. Please refer
|
||||
%to the following paper:
|
||||
%
|
||||
%Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
|
||||
%quality assessment: From error measurement to structural similarity"
|
||||
%IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004.
|
||||
%
|
||||
%Kindly report any suggestions or corrections to zhouwang@ieee.org
|
||||
%
|
||||
%----------------------------------------------------------------------
|
||||
%
|
||||
%Input : (1) img1: the first image being compared
|
||||
% (2) img2: the second image being compared
|
||||
% (3) K: constants in the SSIM index formula (see the above
|
||||
% reference). defualt value: K = [0.01 0.03]
|
||||
% (4) window: local window for statistics (see the above
|
||||
% reference). default widnow is Gaussian given by
|
||||
% window = fspecial('gaussian', 11, 1.5);
|
||||
% (5) L: dynamic range of the images. default: L = 255
|
||||
%
|
||||
%Output: (1) mssim: the mean SSIM index value between 2 images.
|
||||
% If one of the images being compared is regarded as
|
||||
% perfect quality, then mssim can be considered as the
|
||||
% quality measure of the other image.
|
||||
% If img1 = img2, then mssim = 1.
|
||||
% (2) ssim_map: the SSIM index map of the test image. The map
|
||||
% has a smaller size than the input images. The actual size:
|
||||
% size(img1) - size(window) + 1.
|
||||
%
|
||||
%Default Usage:
|
||||
% Given 2 test images img1 and img2, whose dynamic range is 0-255
|
||||
%
|
||||
% [mssim ssim_map] = ssim_index(img1, img2);
|
||||
%
|
||||
%Advanced Usage:
|
||||
% User defined parameters. For example
|
||||
%
|
||||
% K = [0.05 0.05];
|
||||
% window = ones(8);
|
||||
% L = 100;
|
||||
% [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
|
||||
%
|
||||
%See the results:
|
||||
%
|
||||
% mssim %Gives the mssim value
|
||||
% imshow(max(0, ssim_map).^4) %Shows the SSIM index map
|
||||
%
|
||||
%========================================================================
|
||||
|
||||
|
||||
if (nargin < 2 || nargin > 5)
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return;
|
||||
end
|
||||
|
||||
if (size(img1) ~= size(img2))
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return;
|
||||
end
|
||||
|
||||
[M, N] = size(img1);
|
||||
|
||||
if (nargin == 2)
|
||||
if ((M < 11) || (N < 11))
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return
|
||||
end
|
||||
window = fspecial('gaussian', 11, 1.5); %
|
||||
K(1) = 0.01; % default settings
|
||||
K(2) = 0.03; %
|
||||
L = 255; %
|
||||
end
|
||||
|
||||
if (nargin == 3)
|
||||
if ((M < 11) || (N < 11))
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return
|
||||
end
|
||||
window = fspecial('gaussian', 11, 1.5);
|
||||
L = 255;
|
||||
if (length(K) == 2)
|
||||
if (K(1) < 0 || K(2) < 0)
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return;
|
||||
end
|
||||
else
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return;
|
||||
end
|
||||
end
|
||||
|
||||
if (nargin == 4)
|
||||
[H, W] = size(window);
|
||||
if ((H*W) < 4 || (H > M) || (W > N))
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return
|
||||
end
|
||||
L = 255;
|
||||
if (length(K) == 2)
|
||||
if (K(1) < 0 || K(2) < 0)
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return;
|
||||
end
|
||||
else
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return;
|
||||
end
|
||||
end
|
||||
|
||||
if (nargin == 5)
|
||||
[H, W] = size(window);
|
||||
if ((H*W) < 4 || (H > M) || (W > N))
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return
|
||||
end
|
||||
if (length(K) == 2)
|
||||
if (K(1) < 0 || K(2) < 0)
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return;
|
||||
end
|
||||
else
|
||||
ssim_index = -Inf;
|
||||
ssim_map = -Inf;
|
||||
return;
|
||||
end
|
||||
end
|
||||
|
||||
C1 = (K(1)*L)^2;
|
||||
C2 = (K(2)*L)^2;
|
||||
window = window/sum(sum(window));
|
||||
img1 = double(img1);
|
||||
img2 = double(img2);
|
||||
|
||||
mu1 = filter2(window, img1, 'valid');
|
||||
mu2 = filter2(window, img2, 'valid');
|
||||
mu1_sq = mu1.*mu1;
|
||||
mu2_sq = mu2.*mu2;
|
||||
mu1_mu2 = mu1.*mu2;
|
||||
sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
|
||||
sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
|
||||
sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
|
||||
|
||||
if (C1 > 0 && C2 > 0)
|
||||
ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
|
||||
else
|
||||
numerator1 = 2*mu1_mu2 + C1;
|
||||
numerator2 = 2*sigma12 + C2;
|
||||
denominator1 = mu1_sq + mu2_sq + C1;
|
||||
denominator2 = sigma1_sq + sigma2_sq + C2;
|
||||
ssim_map = ones(size(mu1));
|
||||
index = (denominator1.*denominator2 > 0);
|
||||
ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
|
||||
index = (denominator1 ~= 0) & (denominator2 == 0);
|
||||
ssim_map(index) = numerator1(index)./denominator1(index);
|
||||
end
|
||||
|
||||
mssim = mean2(ssim_map);
|
||||
|
||||
end
|
|
@ -1,147 +0,0 @@
|
|||
'''
|
||||
calculate the PSNR and SSIM.
|
||||
same as MATLAB's results
|
||||
'''
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
import glob
|
||||
|
||||
|
||||
def main():
|
||||
# Configurations
|
||||
|
||||
# GT - Ground-truth;
|
||||
# Gen: Generated / Restored / Recovered images
|
||||
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'
|
||||
folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'
|
||||
|
||||
crop_border = 4
|
||||
suffix = '' # suffix for Gen images
|
||||
test_Y = False # True: test Y channel only; False: test RGB channels
|
||||
|
||||
PSNR_all = []
|
||||
SSIM_all = []
|
||||
img_list = sorted(glob.glob(folder_GT + '/*'))
|
||||
|
||||
if test_Y:
|
||||
print('Testing Y channel.')
|
||||
else:
|
||||
print('Testing RGB channels.')
|
||||
|
||||
for i, img_path in enumerate(img_list):
|
||||
base_name = os.path.splitext(os.path.basename(img_path))[0]
|
||||
im_GT = cv2.imread(img_path) / 255.
|
||||
im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255.
|
||||
|
||||
if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space
|
||||
im_GT_in = bgr2ycbcr(im_GT)
|
||||
im_Gen_in = bgr2ycbcr(im_Gen)
|
||||
else:
|
||||
im_GT_in = im_GT
|
||||
im_Gen_in = im_Gen
|
||||
|
||||
# crop borders
|
||||
if im_GT_in.ndim == 3:
|
||||
cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :]
|
||||
cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :]
|
||||
elif im_GT_in.ndim == 2:
|
||||
cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border]
|
||||
cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border]
|
||||
else:
|
||||
raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim))
|
||||
|
||||
# calculate PSNR and SSIM
|
||||
PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255)
|
||||
|
||||
SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255)
|
||||
print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format(
|
||||
i + 1, base_name, PSNR, SSIM))
|
||||
PSNR_all.append(PSNR)
|
||||
SSIM_all.append(SSIM)
|
||||
print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format(
|
||||
sum(PSNR_all) / len(PSNR_all),
|
||||
sum(SSIM_all) / len(SSIM_all)))
|
||||
|
||||
|
||||
def calculate_psnr(img1, img2):
|
||||
# img1 and img2 have range [0, 255]
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
mse = np.mean((img1 - img2)**2)
|
||||
if mse == 0:
|
||||
return float('inf')
|
||||
return 20 * math.log10(255.0 / math.sqrt(mse))
|
||||
|
||||
|
||||
def ssim(img1, img2):
|
||||
C1 = (0.01 * 255)**2
|
||||
C2 = (0.03 * 255)**2
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||
window = np.outer(kernel, kernel.transpose())
|
||||
|
||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||
mu1_sq = mu1**2
|
||||
mu2_sq = mu2**2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
||||
(sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
|
||||
|
||||
def calculate_ssim(img1, img2):
|
||||
'''calculate SSIM
|
||||
the same outputs as MATLAB's
|
||||
img1, img2: [0, 255]
|
||||
'''
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
if img1.ndim == 2:
|
||||
return ssim(img1, img2)
|
||||
elif img1.ndim == 3:
|
||||
if img1.shape[2] == 3:
|
||||
ssims = []
|
||||
for i in range(3):
|
||||
ssims.append(ssim(img1, img2))
|
||||
return np.array(ssims).mean()
|
||||
elif img1.shape[2] == 1:
|
||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||
else:
|
||||
raise ValueError('Wrong input image dimensions.')
|
||||
|
||||
|
||||
def bgr2ycbcr(img, only_y=True):
|
||||
'''same as matlab rgb2ycbcr
|
||||
only_y: only return Y channel
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
'''
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.
|
||||
# convert
|
||||
if only_y:
|
||||
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
||||
else:
|
||||
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
||||
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1 +0,0 @@
|
|||
Subproject commit db2b7899ea8506e90418dbd389300c49bdbb55c3
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.archs.srflow_orig import thops
|
||||
from models.srflow_orig import thops
|
||||
|
||||
|
||||
class _ActNorm(nn.Module):
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.archs.srflow_orig import thops
|
||||
from models.srflow_orig import thops
|
||||
from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
|
||||
from utils.util import opt_get
|
||||
|
|
@ -2,8 +2,6 @@ import torch
|
|||
from torch import nn as nn
|
||||
|
||||
import models.archs.srflow_orig.Permutations
|
||||
from models.archs.srflow_orig import flow, thops, FlowAffineCouplingsAblation
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
def getConditional(rrdbResults, position):
|
|
@ -3,7 +3,8 @@ import torch
|
|||
from torch import nn as nn
|
||||
|
||||
import models.archs.srflow_orig.Split
|
||||
from models.archs.srflow_orig import flow, thops
|
||||
from models.archs.srflow_orig import flow
|
||||
from models.srflow_orig import thops
|
||||
from models.archs.srflow_orig.Split import Split2d
|
||||
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
|
||||
from models.archs.srflow_orig.FlowStep import FlowStep
|
|
@ -3,7 +3,7 @@ import torch
|
|||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.archs.srflow_orig import thops
|
||||
from models.srflow_orig import thops
|
||||
|
||||
|
||||
class InvertibleConv1x1(nn.Module):
|
|
@ -3,11 +3,10 @@ import math
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import numpy as np
|
||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
||||
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
|
||||
import models.archs.srflow_orig.thops as thops
|
||||
import models.srflow_orig.thops as thops
|
||||
import models.archs.srflow_orig.flow as flow
|
||||
from utils.util import opt_get
|
||||
|
|
@ -1,8 +1,7 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.archs.srflow_orig import thops
|
||||
from models.archs.srflow_orig.FlowStep import FlowStep
|
||||
from models.srflow_orig import thops
|
||||
from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
|
||||
from utils.util import opt_get
|
||||
|
|
@ -2,7 +2,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from models.injectors import Injector
|
||||
from trainer.injectors import Injector
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
|
@ -11,7 +11,7 @@ import torchvision.transforms.functional as F
|
|||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from utils import options as option
|
||||
import utils.util as util
|
||||
from data import create_dataloader
|
||||
|
|
|
@ -2,7 +2,6 @@ import argparse
|
|||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from glob import glob
|
||||
|
||||
import torch
|
||||
|
@ -14,9 +13,8 @@ from tqdm import tqdm
|
|||
import utils.options as option
|
||||
|
||||
import utils
|
||||
from data import create_dataset, create_dataloader
|
||||
from data.image_corruptor import ImageCorruptor
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from utils import util
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import os.path as osp
|
||||
import logging
|
||||
import shutil
|
||||
import time
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
|
||||
import os
|
||||
|
||||
|
@ -12,15 +10,10 @@ import torchvision
|
|||
import utils
|
||||
import utils.options as option
|
||||
import utils.util as util
|
||||
from data.util import bgr2ycbcr
|
||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
||||
from switched_conv.switched_conv import compute_attention_specificity
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from data import create_dataset, create_dataloader
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import models.networks as networks
|
||||
|
||||
if __name__ == "__main__":
|
||||
#### options
|
||||
|
|
|
@ -5,7 +5,7 @@ import utils
|
|||
import utils.options as option
|
||||
import utils.util as util
|
||||
from data import create_dataset, create_dataloader
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
|
||||
|
||||
class PretrainedImagePatchClassifier:
|
||||
|
|
|
@ -6,8 +6,8 @@ import argparse
|
|||
import os
|
||||
|
||||
import utils
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from models.networks import define_F
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from trainer.networks import define_F
|
||||
from utils import options as option
|
||||
import utils.util as util
|
||||
from data import create_dataset, create_dataloader
|
||||
|
|
|
@ -4,20 +4,13 @@ import time
|
|||
import argparse
|
||||
from collections import OrderedDict
|
||||
|
||||
import os
|
||||
|
||||
import utils
|
||||
import utils.options as option
|
||||
import utils.util as util
|
||||
from data.util import bgr2ycbcr
|
||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
||||
from switched_conv.switched_conv import compute_attention_specificity
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from data import create_dataset, create_dataloader
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import models.networks as networks
|
||||
|
||||
|
||||
def forward_pass(model, output_dir, alteration_suffix=''):
|
||||
|
|
|
@ -2,22 +2,16 @@ import os.path as osp
|
|||
import logging
|
||||
import time
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
|
||||
import os
|
||||
|
||||
import utils
|
||||
import utils.options as option
|
||||
import utils.util as util
|
||||
from data.util import bgr2ycbcr
|
||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
||||
from switched_conv.switched_conv import compute_attention_specificity
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from data import create_dataset, create_dataloader
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import models.networks as networks
|
||||
import torchvision
|
||||
|
||||
|
||||
|
|
|
@ -7,11 +7,11 @@ from tqdm import tqdm
|
|||
|
||||
import torch
|
||||
from data.data_sampler import DistIterSampler
|
||||
from models.eval import create_evaluator
|
||||
from trainer.eval import create_evaluator
|
||||
|
||||
from utils import util, options as option
|
||||
from data import create_dataloader, create_dataset
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from time import time
|
||||
|
||||
def init_dist(backend, **kwargs):
|
||||
|
|
|
@ -5,12 +5,12 @@ import torch
|
|||
from torch.nn.parallel import DataParallel
|
||||
import torch.nn as nn
|
||||
|
||||
import models.lr_scheduler as lr_scheduler
|
||||
import models.networks as networks
|
||||
from models.base_model import BaseModel
|
||||
from models.injectors import create_injector
|
||||
from models.steps import ConfigurableStep
|
||||
from models.experiments.experiments import get_experiment_for_name
|
||||
import trainer.lr_scheduler as lr_scheduler
|
||||
import trainer.networks as networks
|
||||
from trainer.base_model import BaseModel
|
||||
from trainer.injectors import create_injector
|
||||
from trainer.steps import ConfigurableStep
|
||||
from trainer.experiments.experiments import get_experiment_for_name
|
||||
import torchvision.utils as utils
|
||||
|
||||
logger = logging.getLogger('base')
|
0
codes/trainer/__init__.py
Normal file
0
codes/trainer/__init__.py
Normal file
|
@ -6,8 +6,8 @@ import torchvision
|
|||
from torch.cuda.amp import autocast
|
||||
|
||||
from data.multiscale_dataset import build_multiscale_patch_index_map
|
||||
from models.injectors import Injector
|
||||
from models.losses import extract_params_from_state
|
||||
from trainer.injectors import Injector
|
||||
from trainer.losses import extract_params_from_state
|
||||
import os.path as osp
|
||||
|
||||
|
||||
|
@ -130,7 +130,7 @@ class ProgressiveGeneratorInjector(Injector):
|
|||
lbl = 'generator_recurrent'
|
||||
else:
|
||||
lbl = 'generator_regular'
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", lbl, str(self.env['step']))
|
||||
base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", lbl, str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
ind = 1
|
||||
for i, o in zip(chain_inputs, chain_outputs):
|
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch.cuda.amp import autocast
|
||||
from models.archs.flownet2.networks import Resample2d
|
||||
from models.archs.flownet2 import flow2img
|
||||
from models.injectors import Injector
|
||||
from trainer.injectors import Injector
|
||||
|
||||
|
||||
def create_stereoscopic_injector(opt, env):
|
|
@ -1,9 +1,9 @@
|
|||
from torch.cuda.amp import autocast
|
||||
|
||||
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||
from models.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
|
||||
from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
|
||||
from models.archs.flownet2.networks import Resample2d
|
||||
from models.injectors import Injector
|
||||
from trainer.injectors import Injector
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import os
|
||||
|
@ -156,7 +156,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
|
||||
if self.env['rank'] > 0:
|
||||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
||||
base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_geninput", str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,)))
|
||||
torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,)))
|
||||
|
@ -345,7 +345,7 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
def produce_teco_visual_debugs(self, sext, lbl, it):
|
||||
if self.env['rank'] > 0:
|
||||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
||||
base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
|
||||
for i in range(6):
|
||||
|
@ -378,7 +378,7 @@ class PingPongLoss(ConfigurableLoss):
|
|||
def produce_teco_visual_debugs(self, imglist):
|
||||
if self.env['rank'] > 0:
|
||||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||
base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
cnt = imglist.shape[1]
|
||||
for i in range(cnt):
|
||||
|
@ -388,7 +388,7 @@ class PingPongLoss(ConfigurableLoss):
|
|||
def produce_teco_visual_debugs2(self, imga, imgb, i):
|
||||
if self.env['rank'] > 0:
|
||||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||
base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
torchvision.utils.save_image(imga.float(), osp.join(base_path, "%s_a.png" % (i, )))
|
||||
torchvision.utils.save_image(imgb.float(), osp.join(base_path, "%s_b.png" % (i, )))
|
|
@ -1,6 +1,6 @@
|
|||
from models.eval.flow_gaussian_nll import FlowGaussianNll
|
||||
from models.eval.sr_style import SrStyleTransferEvaluator
|
||||
from models.eval.style import StyleTransferEvaluator
|
||||
from trainer.eval.flow_gaussian_nll import FlowGaussianNll
|
||||
from trainer.eval.sr_style import SrStyleTransferEvaluator
|
||||
from trainer.eval import StyleTransferEvaluator
|
||||
|
||||
|
||||
def create_evaluator(model, opt_eval, env):
|
|
@ -1,14 +1,8 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import os.path as osp
|
||||
import torchvision
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
import models.eval.evaluator as evaluator
|
||||
from pytorch_fid import fid_score
|
||||
|
||||
import trainer.eval.evaluator as evaluator
|
||||
|
||||
# Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
|
||||
from data.image_folder_dataset import ImageFolderDataset
|
|
@ -5,7 +5,7 @@ import os.path as osp
|
|||
import torchvision
|
||||
from torch.utils.data import BatchSampler
|
||||
|
||||
import models.eval.evaluator as evaluator
|
||||
import trainer.eval.evaluator as evaluator
|
||||
from pytorch_fid import fid_score
|
||||
|
||||
|
||||
|
@ -32,9 +32,9 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
|
|||
|
||||
def perform_eval(self):
|
||||
embedding_generator = self.env['generators'][self.embedding_generator]
|
||||
fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"]))
|
||||
fid_fake_path = osp.join(self.env['base_path'], "../../models", "fid_fake", str(self.env["step"]))
|
||||
os.makedirs(fid_fake_path, exist_ok=True)
|
||||
fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"]))
|
||||
fid_real_path = osp.join(self.env['base_path'], "../../models", "fid_real", str(self.env["step"]))
|
||||
os.makedirs(fid_real_path, exist_ok=True)
|
||||
counter = 0
|
||||
for batch in self.sampler:
|
|
@ -3,7 +3,7 @@ import os
|
|||
import torch
|
||||
import os.path as osp
|
||||
import torchvision
|
||||
import models.eval.evaluator as evaluator
|
||||
import trainer.eval.evaluator as evaluator
|
||||
from pytorch_fid import fid_score
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ class StyleTransferEvaluator(evaluator.Evaluator):
|
|||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||
|
||||
def perform_eval(self):
|
||||
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
|
||||
fid_fake_path = osp.join(self.env['base_path'], "../../models", "fid", str(self.env["step"]))
|
||||
os.makedirs(fid_fake_path, exist_ok=True)
|
||||
counter = 0
|
||||
for i in range(self.batches_per_eval):
|
0
codes/trainer/experiments/__init__.py
Normal file
0
codes/trainer/experiments/__init__.py
Normal file
|
@ -3,8 +3,8 @@ from collections import OrderedDict
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import models.networks as networks
|
||||
import models.lr_scheduler as lr_scheduler
|
||||
import trainer.networks as networks
|
||||
import trainer.lr_scheduler as lr_scheduler
|
||||
from .base_model import BaseModel
|
||||
|
||||
logger = logging.getLogger('base')
|
|
@ -4,19 +4,19 @@ import torch.nn
|
|||
from torch.cuda.amp import autocast
|
||||
|
||||
from utils.weight_scheduler import get_scheduler_for_opt
|
||||
from models.losses import extract_params_from_state
|
||||
from trainer.losses import extract_params_from_state
|
||||
|
||||
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
|
||||
def create_injector(opt_inject, env):
|
||||
type = opt_inject['type']
|
||||
if 'teco_' in type:
|
||||
from models.custom_training_components import create_teco_injector
|
||||
from trainer.custom_training_components import create_teco_injector
|
||||
return create_teco_injector(opt_inject, env)
|
||||
elif 'progressive_' in type:
|
||||
from models.custom_training_components import create_progressive_zoom_injector
|
||||
from trainer.custom_training_components import create_progressive_zoom_injector
|
||||
return create_progressive_zoom_injector(opt_inject, env)
|
||||
elif 'stereoscopic_' in type:
|
||||
from models.custom_training_components import create_stereoscopic_injector
|
||||
from trainer.custom_training_components import create_stereoscopic_injector
|
||||
return create_stereoscopic_injector(opt_inject, env)
|
||||
elif 'igpt' in type:
|
||||
from models.archs.transformers.igpt import gpt2
|
|
@ -2,7 +2,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from models.loss import GANLoss
|
||||
from trainer.loss import GANLoss
|
||||
import random
|
||||
import functools
|
||||
import torch.nn.functional as F
|
||||
|
@ -11,7 +11,7 @@ import torch.nn.functional as F
|
|||
def create_loss(opt_loss, env):
|
||||
type = opt_loss['type']
|
||||
if 'teco_' in type:
|
||||
from models.custom_training_components.tecogan_losses import create_teco_loss
|
||||
from trainer.custom_training_components import create_teco_loss
|
||||
return create_teco_loss(opt_loss, env)
|
||||
elif 'stylegan2_' in type:
|
||||
from models.archs.stylegan import create_stylegan2_loss
|
||||
|
@ -152,9 +152,9 @@ class FeatureLoss(ConfigurableLoss):
|
|||
super(FeatureLoss, self).__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||
import models.networks
|
||||
self.netF = models.networks.define_F(which_model=opt['which_model_F'],
|
||||
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
|
||||
import trainer.networks
|
||||
self.netF = trainer.networks.define_F(which_model=opt['which_model_F'],
|
||||
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
|
||||
if not env['opt']['dist']:
|
||||
self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
|
||||
|
||||
|
@ -178,9 +178,9 @@ class InterpretedFeatureLoss(ConfigurableLoss):
|
|||
super(InterpretedFeatureLoss, self).__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||
import models.networks
|
||||
self.netF_real = models.networks.define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
||||
self.netF_gen = models.networks.define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
|
||||
import trainer.networks
|
||||
self.netF_real = trainer.networks.define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
||||
self.netF_gen = trainer.networks.define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
|
||||
if not env['opt']['dist']:
|
||||
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
|
||||
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
|
@ -3,10 +3,10 @@ from torch.cuda.amp import GradScaler
|
|||
from utils.loss_accumulator import LossAccumulator
|
||||
from torch.nn import Module
|
||||
import logging
|
||||
from models.losses import create_loss
|
||||
from trainer.losses import create_loss
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from models.injectors import create_injector
|
||||
from trainer.injectors import create_injector
|
||||
from utils.util import recursively_detach
|
||||
|
||||
logger = logging.getLogger('base')
|
|
@ -2,7 +2,7 @@ import argparse
|
|||
import functools
|
||||
import torch
|
||||
from utils import options as option
|
||||
from models.networks import define_G
|
||||
from trainer.networks import define_G
|
||||
|
||||
|
||||
class TracedModule:
|
||||
|
|
Loading…
Reference in New Issue
Block a user