forked from mrq/DL-Art-School
More refactoring
This commit is contained in:
parent
b905b108da
commit
5640e4efe4
8
.gitmodules
vendored
8
.gitmodules
vendored
|
@ -1,13 +1,9 @@
|
||||||
[submodule "flownet2"]
|
[submodule "flownet2"]
|
||||||
path = flownet2
|
path = flownet2
|
||||||
url = https://github.com/NVIDIA/flownet2-pytorch.git
|
url = https://github.com/NVIDIA/flownet2-pytorch.git
|
||||||
[submodule "codes/switched_conv"]
|
[submodule "codes/models/switched_conv"]
|
||||||
path = codes/switched_conv
|
path = codes/models/switched_conv
|
||||||
url = https://github.com/neonbjb/SwitchedConvolutions.git
|
url = https://github.com/neonbjb/SwitchedConvolutions.git
|
||||||
[submodule "codes/models/flownet2"]
|
[submodule "codes/models/flownet2"]
|
||||||
path = codes/models/flownet2
|
path = codes/models/flownet2
|
||||||
url = https://github.com/neonbjb/flownet2-pytorch.git
|
url = https://github.com/neonbjb/flownet2-pytorch.git
|
||||||
branch = master
|
|
||||||
[submodule "codes/models/archs/flownet2"]
|
|
||||||
path = codes/models/archs/flownet2
|
|
||||||
url = https://github.com/neonbjb/flownet2-pytorch.git
|
|
||||||
|
|
|
@ -1,261 +0,0 @@
|
||||||
function calculate_PSNR_SSIM()
|
|
||||||
|
|
||||||
% GT and SR folder
|
|
||||||
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5';
|
|
||||||
folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5';
|
|
||||||
scale = 4;
|
|
||||||
suffix = ''; % suffix for SR images
|
|
||||||
test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels
|
|
||||||
if test_Y
|
|
||||||
fprintf('Tesing Y channel.\n');
|
|
||||||
else
|
|
||||||
fprintf('Tesing RGB channels.\n');
|
|
||||||
end
|
|
||||||
filepaths = dir(fullfile(folder_GT, '*.png'));
|
|
||||||
PSNR_all = zeros(1, length(filepaths));
|
|
||||||
SSIM_all = zeros(1, length(filepaths));
|
|
||||||
|
|
||||||
for idx_im = 1:length(filepaths)
|
|
||||||
im_name = filepaths(idx_im).name;
|
|
||||||
im_GT = imread(fullfile(folder_GT, im_name));
|
|
||||||
im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png']));
|
|
||||||
|
|
||||||
if test_Y % evaluate on Y channel in YCbCr color space
|
|
||||||
if size(im_GT, 3) == 3
|
|
||||||
im_GT_YCbCr = rgb2ycbcr(im2double(im_GT));
|
|
||||||
im_GT_in = im_GT_YCbCr(:,:,1);
|
|
||||||
im_SR_YCbCr = rgb2ycbcr(im2double(im_SR));
|
|
||||||
im_SR_in = im_SR_YCbCr(:,:,1);
|
|
||||||
else
|
|
||||||
im_GT_in = im2double(im_GT);
|
|
||||||
im_SR_in = im2double(im_SR);
|
|
||||||
end
|
|
||||||
else % evaluate on RGB channels
|
|
||||||
im_GT_in = im2double(im_GT);
|
|
||||||
im_SR_in = im2double(im_SR);
|
|
||||||
end
|
|
||||||
|
|
||||||
% calculate PSNR and SSIM
|
|
||||||
PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale);
|
|
||||||
SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale);
|
|
||||||
fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im));
|
|
||||||
end
|
|
||||||
|
|
||||||
fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all));
|
|
||||||
end
|
|
||||||
|
|
||||||
function res = calculate_PSNR(GT, SR, border)
|
|
||||||
% remove border
|
|
||||||
GT = GT(border+1:end-border, border+1:end-border, :);
|
|
||||||
SR = SR(border+1:end-border, border+1:end-border, :);
|
|
||||||
% calculate PNSR (assume in [0,255])
|
|
||||||
error = GT(:) - SR(:);
|
|
||||||
mse = mean(error.^2);
|
|
||||||
res = 10 * log10(255^2/mse);
|
|
||||||
end
|
|
||||||
|
|
||||||
function res = calculate_SSIM(GT, SR, border)
|
|
||||||
GT = GT(border+1:end-border, border+1:end-border, :);
|
|
||||||
SR = SR(border+1:end-border, border+1:end-border, :);
|
|
||||||
% calculate SSIM
|
|
||||||
mssim = zeros(1, size(SR, 3));
|
|
||||||
for i = 1:size(SR,3)
|
|
||||||
[mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i));
|
|
||||||
end
|
|
||||||
res = mean(mssim);
|
|
||||||
end
|
|
||||||
|
|
||||||
function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
|
|
||||||
|
|
||||||
%========================================================================
|
|
||||||
%SSIM Index, Version 1.0
|
|
||||||
%Copyright(c) 2003 Zhou Wang
|
|
||||||
%All Rights Reserved.
|
|
||||||
%
|
|
||||||
%The author is with Howard Hughes Medical Institute, and Laboratory
|
|
||||||
%for Computational Vision at Center for Neural Science and Courant
|
|
||||||
%Institute of Mathematical Sciences, New York University.
|
|
||||||
%
|
|
||||||
%----------------------------------------------------------------------
|
|
||||||
%Permission to use, copy, or modify this software and its documentation
|
|
||||||
%for educational and research purposes only and without fee is hereby
|
|
||||||
%granted, provided that this copyright notice and the original authors'
|
|
||||||
%names appear on all copies and supporting documentation. This program
|
|
||||||
%shall not be used, rewritten, or adapted as the basis of a commercial
|
|
||||||
%software or hardware product without first obtaining permission of the
|
|
||||||
%authors. The authors make no representations about the suitability of
|
|
||||||
%this software for any purpose. It is provided "as is" without express
|
|
||||||
%or implied warranty.
|
|
||||||
%----------------------------------------------------------------------
|
|
||||||
%
|
|
||||||
%This is an implementation of the algorithm for calculating the
|
|
||||||
%Structural SIMilarity (SSIM) index between two images. Please refer
|
|
||||||
%to the following paper:
|
|
||||||
%
|
|
||||||
%Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
|
|
||||||
%quality assessment: From error measurement to structural similarity"
|
|
||||||
%IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004.
|
|
||||||
%
|
|
||||||
%Kindly report any suggestions or corrections to zhouwang@ieee.org
|
|
||||||
%
|
|
||||||
%----------------------------------------------------------------------
|
|
||||||
%
|
|
||||||
%Input : (1) img1: the first image being compared
|
|
||||||
% (2) img2: the second image being compared
|
|
||||||
% (3) K: constants in the SSIM index formula (see the above
|
|
||||||
% reference). defualt value: K = [0.01 0.03]
|
|
||||||
% (4) window: local window for statistics (see the above
|
|
||||||
% reference). default widnow is Gaussian given by
|
|
||||||
% window = fspecial('gaussian', 11, 1.5);
|
|
||||||
% (5) L: dynamic range of the images. default: L = 255
|
|
||||||
%
|
|
||||||
%Output: (1) mssim: the mean SSIM index value between 2 images.
|
|
||||||
% If one of the images being compared is regarded as
|
|
||||||
% perfect quality, then mssim can be considered as the
|
|
||||||
% quality measure of the other image.
|
|
||||||
% If img1 = img2, then mssim = 1.
|
|
||||||
% (2) ssim_map: the SSIM index map of the test image. The map
|
|
||||||
% has a smaller size than the input images. The actual size:
|
|
||||||
% size(img1) - size(window) + 1.
|
|
||||||
%
|
|
||||||
%Default Usage:
|
|
||||||
% Given 2 test images img1 and img2, whose dynamic range is 0-255
|
|
||||||
%
|
|
||||||
% [mssim ssim_map] = ssim_index(img1, img2);
|
|
||||||
%
|
|
||||||
%Advanced Usage:
|
|
||||||
% User defined parameters. For example
|
|
||||||
%
|
|
||||||
% K = [0.05 0.05];
|
|
||||||
% window = ones(8);
|
|
||||||
% L = 100;
|
|
||||||
% [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
|
|
||||||
%
|
|
||||||
%See the results:
|
|
||||||
%
|
|
||||||
% mssim %Gives the mssim value
|
|
||||||
% imshow(max(0, ssim_map).^4) %Shows the SSIM index map
|
|
||||||
%
|
|
||||||
%========================================================================
|
|
||||||
|
|
||||||
|
|
||||||
if (nargin < 2 || nargin > 5)
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return;
|
|
||||||
end
|
|
||||||
|
|
||||||
if (size(img1) ~= size(img2))
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return;
|
|
||||||
end
|
|
||||||
|
|
||||||
[M, N] = size(img1);
|
|
||||||
|
|
||||||
if (nargin == 2)
|
|
||||||
if ((M < 11) || (N < 11))
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return
|
|
||||||
end
|
|
||||||
window = fspecial('gaussian', 11, 1.5); %
|
|
||||||
K(1) = 0.01; % default settings
|
|
||||||
K(2) = 0.03; %
|
|
||||||
L = 255; %
|
|
||||||
end
|
|
||||||
|
|
||||||
if (nargin == 3)
|
|
||||||
if ((M < 11) || (N < 11))
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return
|
|
||||||
end
|
|
||||||
window = fspecial('gaussian', 11, 1.5);
|
|
||||||
L = 255;
|
|
||||||
if (length(K) == 2)
|
|
||||||
if (K(1) < 0 || K(2) < 0)
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return;
|
|
||||||
end
|
|
||||||
else
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return;
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if (nargin == 4)
|
|
||||||
[H, W] = size(window);
|
|
||||||
if ((H*W) < 4 || (H > M) || (W > N))
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return
|
|
||||||
end
|
|
||||||
L = 255;
|
|
||||||
if (length(K) == 2)
|
|
||||||
if (K(1) < 0 || K(2) < 0)
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return;
|
|
||||||
end
|
|
||||||
else
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return;
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if (nargin == 5)
|
|
||||||
[H, W] = size(window);
|
|
||||||
if ((H*W) < 4 || (H > M) || (W > N))
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return
|
|
||||||
end
|
|
||||||
if (length(K) == 2)
|
|
||||||
if (K(1) < 0 || K(2) < 0)
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return;
|
|
||||||
end
|
|
||||||
else
|
|
||||||
ssim_index = -Inf;
|
|
||||||
ssim_map = -Inf;
|
|
||||||
return;
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
C1 = (K(1)*L)^2;
|
|
||||||
C2 = (K(2)*L)^2;
|
|
||||||
window = window/sum(sum(window));
|
|
||||||
img1 = double(img1);
|
|
||||||
img2 = double(img2);
|
|
||||||
|
|
||||||
mu1 = filter2(window, img1, 'valid');
|
|
||||||
mu2 = filter2(window, img2, 'valid');
|
|
||||||
mu1_sq = mu1.*mu1;
|
|
||||||
mu2_sq = mu2.*mu2;
|
|
||||||
mu1_mu2 = mu1.*mu2;
|
|
||||||
sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
|
|
||||||
sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
|
|
||||||
sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
|
|
||||||
|
|
||||||
if (C1 > 0 && C2 > 0)
|
|
||||||
ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
|
|
||||||
else
|
|
||||||
numerator1 = 2*mu1_mu2 + C1;
|
|
||||||
numerator2 = 2*sigma12 + C2;
|
|
||||||
denominator1 = mu1_sq + mu2_sq + C1;
|
|
||||||
denominator2 = sigma1_sq + sigma2_sq + C2;
|
|
||||||
ssim_map = ones(size(mu1));
|
|
||||||
index = (denominator1.*denominator2 > 0);
|
|
||||||
ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
|
|
||||||
index = (denominator1 ~= 0) & (denominator2 == 0);
|
|
||||||
ssim_map(index) = numerator1(index)./denominator1(index);
|
|
||||||
end
|
|
||||||
|
|
||||||
mssim = mean2(ssim_map);
|
|
||||||
|
|
||||||
end
|
|
|
@ -1,147 +0,0 @@
|
||||||
'''
|
|
||||||
calculate the PSNR and SSIM.
|
|
||||||
same as MATLAB's results
|
|
||||||
'''
|
|
||||||
import os
|
|
||||||
import math
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import glob
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Configurations
|
|
||||||
|
|
||||||
# GT - Ground-truth;
|
|
||||||
# Gen: Generated / Restored / Recovered images
|
|
||||||
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'
|
|
||||||
folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'
|
|
||||||
|
|
||||||
crop_border = 4
|
|
||||||
suffix = '' # suffix for Gen images
|
|
||||||
test_Y = False # True: test Y channel only; False: test RGB channels
|
|
||||||
|
|
||||||
PSNR_all = []
|
|
||||||
SSIM_all = []
|
|
||||||
img_list = sorted(glob.glob(folder_GT + '/*'))
|
|
||||||
|
|
||||||
if test_Y:
|
|
||||||
print('Testing Y channel.')
|
|
||||||
else:
|
|
||||||
print('Testing RGB channels.')
|
|
||||||
|
|
||||||
for i, img_path in enumerate(img_list):
|
|
||||||
base_name = os.path.splitext(os.path.basename(img_path))[0]
|
|
||||||
im_GT = cv2.imread(img_path) / 255.
|
|
||||||
im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255.
|
|
||||||
|
|
||||||
if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space
|
|
||||||
im_GT_in = bgr2ycbcr(im_GT)
|
|
||||||
im_Gen_in = bgr2ycbcr(im_Gen)
|
|
||||||
else:
|
|
||||||
im_GT_in = im_GT
|
|
||||||
im_Gen_in = im_Gen
|
|
||||||
|
|
||||||
# crop borders
|
|
||||||
if im_GT_in.ndim == 3:
|
|
||||||
cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :]
|
|
||||||
cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :]
|
|
||||||
elif im_GT_in.ndim == 2:
|
|
||||||
cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border]
|
|
||||||
cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border]
|
|
||||||
else:
|
|
||||||
raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim))
|
|
||||||
|
|
||||||
# calculate PSNR and SSIM
|
|
||||||
PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255)
|
|
||||||
|
|
||||||
SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255)
|
|
||||||
print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format(
|
|
||||||
i + 1, base_name, PSNR, SSIM))
|
|
||||||
PSNR_all.append(PSNR)
|
|
||||||
SSIM_all.append(SSIM)
|
|
||||||
print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format(
|
|
||||||
sum(PSNR_all) / len(PSNR_all),
|
|
||||||
sum(SSIM_all) / len(SSIM_all)))
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_psnr(img1, img2):
|
|
||||||
# img1 and img2 have range [0, 255]
|
|
||||||
img1 = img1.astype(np.float64)
|
|
||||||
img2 = img2.astype(np.float64)
|
|
||||||
mse = np.mean((img1 - img2)**2)
|
|
||||||
if mse == 0:
|
|
||||||
return float('inf')
|
|
||||||
return 20 * math.log10(255.0 / math.sqrt(mse))
|
|
||||||
|
|
||||||
|
|
||||||
def ssim(img1, img2):
|
|
||||||
C1 = (0.01 * 255)**2
|
|
||||||
C2 = (0.03 * 255)**2
|
|
||||||
|
|
||||||
img1 = img1.astype(np.float64)
|
|
||||||
img2 = img2.astype(np.float64)
|
|
||||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
|
||||||
window = np.outer(kernel, kernel.transpose())
|
|
||||||
|
|
||||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
|
||||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
|
||||||
mu1_sq = mu1**2
|
|
||||||
mu2_sq = mu2**2
|
|
||||||
mu1_mu2 = mu1 * mu2
|
|
||||||
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
|
||||||
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
|
||||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
|
||||||
|
|
||||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
|
||||||
(sigma1_sq + sigma2_sq + C2))
|
|
||||||
return ssim_map.mean()
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_ssim(img1, img2):
|
|
||||||
'''calculate SSIM
|
|
||||||
the same outputs as MATLAB's
|
|
||||||
img1, img2: [0, 255]
|
|
||||||
'''
|
|
||||||
if not img1.shape == img2.shape:
|
|
||||||
raise ValueError('Input images must have the same dimensions.')
|
|
||||||
if img1.ndim == 2:
|
|
||||||
return ssim(img1, img2)
|
|
||||||
elif img1.ndim == 3:
|
|
||||||
if img1.shape[2] == 3:
|
|
||||||
ssims = []
|
|
||||||
for i in range(3):
|
|
||||||
ssims.append(ssim(img1, img2))
|
|
||||||
return np.array(ssims).mean()
|
|
||||||
elif img1.shape[2] == 1:
|
|
||||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
|
||||||
else:
|
|
||||||
raise ValueError('Wrong input image dimensions.')
|
|
||||||
|
|
||||||
|
|
||||||
def bgr2ycbcr(img, only_y=True):
|
|
||||||
'''same as matlab rgb2ycbcr
|
|
||||||
only_y: only return Y channel
|
|
||||||
Input:
|
|
||||||
uint8, [0, 255]
|
|
||||||
float, [0, 1]
|
|
||||||
'''
|
|
||||||
in_img_type = img.dtype
|
|
||||||
img.astype(np.float32)
|
|
||||||
if in_img_type != np.uint8:
|
|
||||||
img *= 255.
|
|
||||||
# convert
|
|
||||||
if only_y:
|
|
||||||
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
|
||||||
else:
|
|
||||||
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
|
||||||
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
|
|
||||||
if in_img_type == np.uint8:
|
|
||||||
rlt = rlt.round()
|
|
||||||
else:
|
|
||||||
rlt /= 255.
|
|
||||||
return rlt.astype(in_img_type)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
|
@ -1 +0,0 @@
|
||||||
Subproject commit db2b7899ea8506e90418dbd389300c49bdbb55c3
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from models.archs.srflow_orig import thops
|
from models.srflow_orig import thops
|
||||||
|
|
||||||
|
|
||||||
class _ActNorm(nn.Module):
|
class _ActNorm(nn.Module):
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from models.archs.srflow_orig import thops
|
from models.srflow_orig import thops
|
||||||
from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
|
from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
|
@ -2,8 +2,6 @@ import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
import models.archs.srflow_orig.Permutations
|
import models.archs.srflow_orig.Permutations
|
||||||
from models.archs.srflow_orig import flow, thops, FlowAffineCouplingsAblation
|
|
||||||
from utils.util import opt_get
|
|
||||||
|
|
||||||
|
|
||||||
def getConditional(rrdbResults, position):
|
def getConditional(rrdbResults, position):
|
|
@ -3,7 +3,8 @@ import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
import models.archs.srflow_orig.Split
|
import models.archs.srflow_orig.Split
|
||||||
from models.archs.srflow_orig import flow, thops
|
from models.archs.srflow_orig import flow
|
||||||
|
from models.srflow_orig import thops
|
||||||
from models.archs.srflow_orig.Split import Split2d
|
from models.archs.srflow_orig.Split import Split2d
|
||||||
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
|
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
|
||||||
from models.archs.srflow_orig.FlowStep import FlowStep
|
from models.archs.srflow_orig.FlowStep import FlowStep
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from models.archs.srflow_orig import thops
|
from models.srflow_orig import thops
|
||||||
|
|
||||||
|
|
||||||
class InvertibleConv1x1(nn.Module):
|
class InvertibleConv1x1(nn.Module):
|
|
@ -3,11 +3,10 @@ import math
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
||||||
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
|
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
|
||||||
import models.archs.srflow_orig.thops as thops
|
import models.srflow_orig.thops as thops
|
||||||
import models.archs.srflow_orig.flow as flow
|
import models.archs.srflow_orig.flow as flow
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from models.archs.srflow_orig import thops
|
from models.srflow_orig import thops
|
||||||
from models.archs.srflow_orig.FlowStep import FlowStep
|
|
||||||
from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
|
from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from models.injectors import Injector
|
from trainer.injectors import Injector
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ import torchvision.transforms.functional as F
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||||
from utils import options as option
|
from utils import options as option
|
||||||
import utils.util as util
|
import utils.util as util
|
||||||
from data import create_dataloader
|
from data import create_dataloader
|
||||||
|
|
|
@ -2,7 +2,6 @@ import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -14,9 +13,8 @@ from tqdm import tqdm
|
||||||
import utils.options as option
|
import utils.options as option
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
from data import create_dataset, create_dataloader
|
|
||||||
from data.image_corruptor import ImageCorruptor
|
from data.image_corruptor import ImageCorruptor
|
||||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||||
from utils import util
|
from utils import util
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
@ -12,15 +10,10 @@ import torchvision
|
||||||
import utils
|
import utils
|
||||||
import utils.options as option
|
import utils.options as option
|
||||||
import utils.util as util
|
import utils.util as util
|
||||||
from data.util import bgr2ycbcr
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
|
||||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
|
||||||
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
|
||||||
from switched_conv.switched_conv import compute_attention_specificity
|
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
import models.networks as networks
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#### options
|
#### options
|
||||||
|
|
|
@ -5,7 +5,7 @@ import utils
|
||||||
import utils.options as option
|
import utils.options as option
|
||||||
import utils.util as util
|
import utils.util as util
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||||
|
|
||||||
|
|
||||||
class PretrainedImagePatchClassifier:
|
class PretrainedImagePatchClassifier:
|
||||||
|
|
|
@ -6,8 +6,8 @@ import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||||
from models.networks import define_F
|
from trainer.networks import define_F
|
||||||
from utils import options as option
|
from utils import options as option
|
||||||
import utils.util as util
|
import utils.util as util
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
|
|
|
@ -4,20 +4,13 @@ import time
|
||||||
import argparse
|
import argparse
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
import utils.options as option
|
import utils.options as option
|
||||||
import utils.util as util
|
import utils.util as util
|
||||||
from data.util import bgr2ycbcr
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
|
||||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
|
||||||
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
|
||||||
from switched_conv.switched_conv import compute_attention_specificity
|
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
import models.networks as networks
|
|
||||||
|
|
||||||
|
|
||||||
def forward_pass(model, output_dir, alteration_suffix=''):
|
def forward_pass(model, output_dir, alteration_suffix=''):
|
||||||
|
|
|
@ -2,22 +2,16 @@ import os.path as osp
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
import utils.options as option
|
import utils.options as option
|
||||||
import utils.util as util
|
import utils.util as util
|
||||||
from data.util import bgr2ycbcr
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
|
||||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
|
||||||
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
|
||||||
from switched_conv.switched_conv import compute_attention_specificity
|
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
import models.networks as networks
|
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,11 +7,11 @@ from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from data.data_sampler import DistIterSampler
|
from data.data_sampler import DistIterSampler
|
||||||
from models.eval import create_evaluator
|
from trainer.eval import create_evaluator
|
||||||
|
|
||||||
from utils import util, options as option
|
from utils import util, options as option
|
||||||
from data import create_dataloader, create_dataset
|
from data import create_dataloader, create_dataset
|
||||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
def init_dist(backend, **kwargs):
|
def init_dist(backend, **kwargs):
|
||||||
|
|
|
@ -5,12 +5,12 @@ import torch
|
||||||
from torch.nn.parallel import DataParallel
|
from torch.nn.parallel import DataParallel
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
import models.lr_scheduler as lr_scheduler
|
import trainer.lr_scheduler as lr_scheduler
|
||||||
import models.networks as networks
|
import trainer.networks as networks
|
||||||
from models.base_model import BaseModel
|
from trainer.base_model import BaseModel
|
||||||
from models.injectors import create_injector
|
from trainer.injectors import create_injector
|
||||||
from models.steps import ConfigurableStep
|
from trainer.steps import ConfigurableStep
|
||||||
from models.experiments.experiments import get_experiment_for_name
|
from trainer.experiments.experiments import get_experiment_for_name
|
||||||
import torchvision.utils as utils
|
import torchvision.utils as utils
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
0
codes/trainer/__init__.py
Normal file
0
codes/trainer/__init__.py
Normal file
|
@ -6,8 +6,8 @@ import torchvision
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from data.multiscale_dataset import build_multiscale_patch_index_map
|
from data.multiscale_dataset import build_multiscale_patch_index_map
|
||||||
from models.injectors import Injector
|
from trainer.injectors import Injector
|
||||||
from models.losses import extract_params_from_state
|
from trainer.losses import extract_params_from_state
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
|
|
||||||
|
@ -130,7 +130,7 @@ class ProgressiveGeneratorInjector(Injector):
|
||||||
lbl = 'generator_recurrent'
|
lbl = 'generator_recurrent'
|
||||||
else:
|
else:
|
||||||
lbl = 'generator_regular'
|
lbl = 'generator_regular'
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", lbl, str(self.env['step']))
|
base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", lbl, str(self.env['step']))
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
ind = 1
|
ind = 1
|
||||||
for i, o in zip(chain_inputs, chain_outputs):
|
for i, o in zip(chain_inputs, chain_outputs):
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
from models.archs.flownet2.networks import Resample2d
|
from models.archs.flownet2.networks import Resample2d
|
||||||
from models.archs.flownet2 import flow2img
|
from models.archs.flownet2 import flow2img
|
||||||
from models.injectors import Injector
|
from trainer.injectors import Injector
|
||||||
|
|
||||||
|
|
||||||
def create_stereoscopic_injector(opt, env):
|
def create_stereoscopic_injector(opt, env):
|
|
@ -1,9 +1,9 @@
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
|
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||||
from models.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
|
from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
|
||||||
from models.archs.flownet2.networks import Resample2d
|
from models.archs.flownet2.networks import Resample2d
|
||||||
from models.injectors import Injector
|
from trainer.injectors import Injector
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import os
|
import os
|
||||||
|
@ -156,7 +156,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
|
def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
|
||||||
if self.env['rank'] > 0:
|
if self.env['rank'] > 0:
|
||||||
return
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_geninput", str(self.env['step']))
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,)))
|
torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,)))
|
||||||
torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,)))
|
torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,)))
|
||||||
|
@ -345,7 +345,7 @@ class TecoGanLoss(ConfigurableLoss):
|
||||||
def produce_teco_visual_debugs(self, sext, lbl, it):
|
def produce_teco_visual_debugs(self, sext, lbl, it):
|
||||||
if self.env['rank'] > 0:
|
if self.env['rank'] > 0:
|
||||||
return
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
|
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
|
||||||
for i in range(6):
|
for i in range(6):
|
||||||
|
@ -378,7 +378,7 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
def produce_teco_visual_debugs(self, imglist):
|
def produce_teco_visual_debugs(self, imglist):
|
||||||
if self.env['rank'] > 0:
|
if self.env['rank'] > 0:
|
||||||
return
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
cnt = imglist.shape[1]
|
cnt = imglist.shape[1]
|
||||||
for i in range(cnt):
|
for i in range(cnt):
|
||||||
|
@ -388,7 +388,7 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
def produce_teco_visual_debugs2(self, imga, imgb, i):
|
def produce_teco_visual_debugs2(self, imga, imgb, i):
|
||||||
if self.env['rank'] > 0:
|
if self.env['rank'] > 0:
|
||||||
return
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
torchvision.utils.save_image(imga.float(), osp.join(base_path, "%s_a.png" % (i, )))
|
torchvision.utils.save_image(imga.float(), osp.join(base_path, "%s_a.png" % (i, )))
|
||||||
torchvision.utils.save_image(imgb.float(), osp.join(base_path, "%s_b.png" % (i, )))
|
torchvision.utils.save_image(imgb.float(), osp.join(base_path, "%s_b.png" % (i, )))
|
|
@ -1,6 +1,6 @@
|
||||||
from models.eval.flow_gaussian_nll import FlowGaussianNll
|
from trainer.eval.flow_gaussian_nll import FlowGaussianNll
|
||||||
from models.eval.sr_style import SrStyleTransferEvaluator
|
from trainer.eval.sr_style import SrStyleTransferEvaluator
|
||||||
from models.eval.style import StyleTransferEvaluator
|
from trainer.eval import StyleTransferEvaluator
|
||||||
|
|
||||||
|
|
||||||
def create_evaluator(model, opt_eval, env):
|
def create_evaluator(model, opt_eval, env):
|
|
@ -1,14 +1,8 @@
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import os.path as osp
|
|
||||||
import torchvision
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import models.eval.evaluator as evaluator
|
import trainer.eval.evaluator as evaluator
|
||||||
from pytorch_fid import fid_score
|
|
||||||
|
|
||||||
|
|
||||||
# Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
|
# Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
|
||||||
from data.image_folder_dataset import ImageFolderDataset
|
from data.image_folder_dataset import ImageFolderDataset
|
|
@ -5,7 +5,7 @@ import os.path as osp
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch.utils.data import BatchSampler
|
from torch.utils.data import BatchSampler
|
||||||
|
|
||||||
import models.eval.evaluator as evaluator
|
import trainer.eval.evaluator as evaluator
|
||||||
from pytorch_fid import fid_score
|
from pytorch_fid import fid_score
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,9 +32,9 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
|
||||||
|
|
||||||
def perform_eval(self):
|
def perform_eval(self):
|
||||||
embedding_generator = self.env['generators'][self.embedding_generator]
|
embedding_generator = self.env['generators'][self.embedding_generator]
|
||||||
fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"]))
|
fid_fake_path = osp.join(self.env['base_path'], "../../models", "fid_fake", str(self.env["step"]))
|
||||||
os.makedirs(fid_fake_path, exist_ok=True)
|
os.makedirs(fid_fake_path, exist_ok=True)
|
||||||
fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"]))
|
fid_real_path = osp.join(self.env['base_path'], "../../models", "fid_real", str(self.env["step"]))
|
||||||
os.makedirs(fid_real_path, exist_ok=True)
|
os.makedirs(fid_real_path, exist_ok=True)
|
||||||
counter = 0
|
counter = 0
|
||||||
for batch in self.sampler:
|
for batch in self.sampler:
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import torch
|
import torch
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import torchvision
|
import torchvision
|
||||||
import models.eval.evaluator as evaluator
|
import trainer.eval.evaluator as evaluator
|
||||||
from pytorch_fid import fid_score
|
from pytorch_fid import fid_score
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ class StyleTransferEvaluator(evaluator.Evaluator):
|
||||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||||
|
|
||||||
def perform_eval(self):
|
def perform_eval(self):
|
||||||
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
|
fid_fake_path = osp.join(self.env['base_path'], "../../models", "fid", str(self.env["step"]))
|
||||||
os.makedirs(fid_fake_path, exist_ok=True)
|
os.makedirs(fid_fake_path, exist_ok=True)
|
||||||
counter = 0
|
counter = 0
|
||||||
for i in range(self.batches_per_eval):
|
for i in range(self.batches_per_eval):
|
0
codes/trainer/experiments/__init__.py
Normal file
0
codes/trainer/experiments/__init__.py
Normal file
|
@ -3,8 +3,8 @@ from collections import OrderedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import models.networks as networks
|
import trainer.networks as networks
|
||||||
import models.lr_scheduler as lr_scheduler
|
import trainer.lr_scheduler as lr_scheduler
|
||||||
from .base_model import BaseModel
|
from .base_model import BaseModel
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
|
@ -4,19 +4,19 @@ import torch.nn
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from utils.weight_scheduler import get_scheduler_for_opt
|
from utils.weight_scheduler import get_scheduler_for_opt
|
||||||
from models.losses import extract_params_from_state
|
from trainer.losses import extract_params_from_state
|
||||||
|
|
||||||
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
|
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
|
||||||
def create_injector(opt_inject, env):
|
def create_injector(opt_inject, env):
|
||||||
type = opt_inject['type']
|
type = opt_inject['type']
|
||||||
if 'teco_' in type:
|
if 'teco_' in type:
|
||||||
from models.custom_training_components import create_teco_injector
|
from trainer.custom_training_components import create_teco_injector
|
||||||
return create_teco_injector(opt_inject, env)
|
return create_teco_injector(opt_inject, env)
|
||||||
elif 'progressive_' in type:
|
elif 'progressive_' in type:
|
||||||
from models.custom_training_components import create_progressive_zoom_injector
|
from trainer.custom_training_components import create_progressive_zoom_injector
|
||||||
return create_progressive_zoom_injector(opt_inject, env)
|
return create_progressive_zoom_injector(opt_inject, env)
|
||||||
elif 'stereoscopic_' in type:
|
elif 'stereoscopic_' in type:
|
||||||
from models.custom_training_components import create_stereoscopic_injector
|
from trainer.custom_training_components import create_stereoscopic_injector
|
||||||
return create_stereoscopic_injector(opt_inject, env)
|
return create_stereoscopic_injector(opt_inject, env)
|
||||||
elif 'igpt' in type:
|
elif 'igpt' in type:
|
||||||
from models.archs.transformers.igpt import gpt2
|
from models.archs.transformers.igpt import gpt2
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from models.loss import GANLoss
|
from trainer.loss import GANLoss
|
||||||
import random
|
import random
|
||||||
import functools
|
import functools
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -11,7 +11,7 @@ import torch.nn.functional as F
|
||||||
def create_loss(opt_loss, env):
|
def create_loss(opt_loss, env):
|
||||||
type = opt_loss['type']
|
type = opt_loss['type']
|
||||||
if 'teco_' in type:
|
if 'teco_' in type:
|
||||||
from models.custom_training_components.tecogan_losses import create_teco_loss
|
from trainer.custom_training_components import create_teco_loss
|
||||||
return create_teco_loss(opt_loss, env)
|
return create_teco_loss(opt_loss, env)
|
||||||
elif 'stylegan2_' in type:
|
elif 'stylegan2_' in type:
|
||||||
from models.archs.stylegan import create_stylegan2_loss
|
from models.archs.stylegan import create_stylegan2_loss
|
||||||
|
@ -152,9 +152,9 @@ class FeatureLoss(ConfigurableLoss):
|
||||||
super(FeatureLoss, self).__init__(opt, env)
|
super(FeatureLoss, self).__init__(opt, env)
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||||
import models.networks
|
import trainer.networks
|
||||||
self.netF = models.networks.define_F(which_model=opt['which_model_F'],
|
self.netF = trainer.networks.define_F(which_model=opt['which_model_F'],
|
||||||
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
|
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
|
||||||
if not env['opt']['dist']:
|
if not env['opt']['dist']:
|
||||||
self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
|
self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
|
||||||
|
|
||||||
|
@ -178,9 +178,9 @@ class InterpretedFeatureLoss(ConfigurableLoss):
|
||||||
super(InterpretedFeatureLoss, self).__init__(opt, env)
|
super(InterpretedFeatureLoss, self).__init__(opt, env)
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||||
import models.networks
|
import trainer.networks
|
||||||
self.netF_real = models.networks.define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
self.netF_real = trainer.networks.define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
||||||
self.netF_gen = models.networks.define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
|
self.netF_gen = trainer.networks.define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
|
||||||
if not env['opt']['dist']:
|
if not env['opt']['dist']:
|
||||||
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
|
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
|
||||||
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
|
@ -3,10 +3,10 @@ from torch.cuda.amp import GradScaler
|
||||||
from utils.loss_accumulator import LossAccumulator
|
from utils.loss_accumulator import LossAccumulator
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
import logging
|
import logging
|
||||||
from models.losses import create_loss
|
from trainer.losses import create_loss
|
||||||
import torch
|
import torch
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from models.injectors import create_injector
|
from trainer.injectors import create_injector
|
||||||
from utils.util import recursively_detach
|
from utils.util import recursively_detach
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
|
@ -2,7 +2,7 @@ import argparse
|
||||||
import functools
|
import functools
|
||||||
import torch
|
import torch
|
||||||
from utils import options as option
|
from utils import options as option
|
||||||
from models.networks import define_G
|
from trainer.networks import define_G
|
||||||
|
|
||||||
|
|
||||||
class TracedModule:
|
class TracedModule:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user