forked from mrq/DL-Art-School
7629cb0e61
New loss type that can replace PSNR loss. It operates in the frequency domain and focuses on the frequency features that are lost during the HR->LR conversion.
75 lines
2.8 KiB
Python
75 lines
2.8 KiB
Python
import torch
|
|
import os
|
|
from PIL import Image
|
|
import numpy as np
|
|
import options.options as option
|
|
from data import create_dataloader, create_dataset
|
|
import math
|
|
from tqdm import tqdm
|
|
from torchvision import transforms
|
|
from utils.fdpl_util import dct_2d, extract_patches_2d
|
|
import random
|
|
import matplotlib.pyplot as plt
|
|
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
from utils.colors import rgb2ycbcr
|
|
import torch.nn.functional as F
|
|
|
|
# Script configuration.

# Path to the training options YAML; handed to `option.parse` by the script.
input_config = "../../options/train_imgset_pixgan_srg4_fdpl.yml"

# File that the computed per-frequency mean differences are saved to
# (written with `torch.save`).
output_file = "fdpr_diff_means.pt"

# NOTE(review): not referenced anywhere else in this script; all tensor math
# appears to run on whatever device the dataloader yields — confirm before
# relying on this setting.
device = 'cuda'

# Side length (in pixels) of the square patches that are cut from each image
# before the DCT is applied.
patch_size = 128
|
|
|
|
def relative_spectral_diff(spectrum_a, spectrum_b, eps=1e-8):
    """Return the element-wise symmetric relative difference of two tensors.

    Computes ``|a - b| / ((|a| + |b|) / 2 + eps)``.

    Args:
        spectrum_a: First tensor.
        spectrum_b: Second tensor; must be broadcastable against the first.
        eps: Small constant added to the denominator so the division stays
            finite where both inputs are zero.

    Returns:
        A tensor holding the relative difference for every element.
    """
    mean_magnitude = (torch.abs(spectrum_a) + torch.abs(spectrum_b)) / 2
    return torch.abs(spectrum_a - spectrum_b) / (mean_magnitude + eps)


if __name__ == '__main__':
    # Computes, for every DCT coefficient position of a patch, the mean
    # relative difference between ground-truth images and their upsampled
    # low-quality counterparts, saves the result and plots it per channel.
    opt = option.parse(input_config, is_train=True)
    opt['dist'] = False

    # Create a dataset to load from (this dataset loads HR/LR images and
    # performs any distortions specified by the YML).
    dataset_opt = opt['datasets']['train']
    train_set = create_dataset(dataset_opt)
    train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
    train_loader = create_dataloader(train_set, dataset_opt, opt, None)
    print('Number of train images: {:,d}, iters: {:,d}'.format(
        len(train_set), train_size))

    # Calculate the perceptual weights.
    num_patches = 0
    batch_diff_sums = []
    for batch_idx, train_data in enumerate(tqdm(train_loader)):
        # Only a sample of the dataset is needed: stop once more than 200
        # batches have been consumed (i.e. at most 201 batches are used).
        if batch_idx > 200:
            break

        hr_img = rgb2ycbcr(train_data['GT'].double())
        # Upsample the LQ image to the GT spatial size so that patches from
        # both images cover the same pixel grid.
        lr_img = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
                                         size=hr_img.shape[2:],
                                         mode="bicubic"))

        patches_hr = extract_patches_2d(img=hr_img,
                                        patch_shape=(patch_size, patch_size),
                                        batch_first=True)
        patches_hr = dct_2d(patches_hr, norm='ortho')
        patches_lr = extract_patches_2d(img=lr_img,
                                        patch_shape=(patch_size, patch_size),
                                        batch_first=True)
        patches_lr = dct_2d(patches_lr, norm='ortho')

        # The patch tensor is 5-D; the first two axes are presumably
        # (batch, patches-per-image) — TODO confirm against
        # extract_patches_2d. They are the axes reduced below.
        b, p, _, _, _ = patches_hr.shape
        diffs = relative_spectral_diff(patches_lr, patches_hr)
        num_patches += b * p
        batch_diff_sums.append(torch.sum(diffs, dim=(0, 1)))

    if not batch_diff_sums:
        raise RuntimeError(
            'The train loader yielded no batches; cannot compute diff means.')

    # Sum the per-batch totals and divide by the number of patches seen to get
    # the mean difference at every coefficient position.
    diff_means = torch.sum(torch.stack(batch_diff_sums, dim=0), dim=0) / num_patches

    torch.save(diff_means, output_file)
    print(diff_means)

    # One heat map per leading-axis entry of diff_means (presumably the three
    # YCbCr channels — TODO confirm).
    for i, channel_means in enumerate(diff_means):
        fig, ax = plt.subplots()
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        heatmap = ax.imshow(channel_means.numpy())
        ax.set_title("mean_diff for channel %i" % (i,))
        fig.colorbar(heatmap, cax=cax, orientation='vertical')
        # NOTE(review): shown per figure; whether show() belongs inside or
        # after this loop could not be determined from the source — confirm.
        plt.show()
|
|
|