Add FDPL Loss

New loss type that can replace the PSNR (pixel) loss. It operates in the frequency domain
and focuses on the frequency components that are lost in the HR->LR conversion.
James Betker 2020-07-30 20:47:57 -06:00
parent 85ee64b8d9
commit 7629cb0e61
6 changed files with 737 additions and 3 deletions

View File

@ -172,7 +172,7 @@ class LQGTDataset(data.Dataset):
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
if self.opt['doResizeLoss']:
if 'doResizeLoss' in self.opt.keys() and self.opt['doResizeLoss']:
r = random.randrange(0, 10)
if r > 5:
img_LQ = cv2.resize(img_LQ, (int(LQ_size/2), int(LQ_size/2)), interpolation=cv2.INTER_LINEAR)
@ -215,7 +215,7 @@ class LQGTDataset(data.Dataset):
corruption_buffer.seek(0)
img_LQ = Image.open(corruption_buffer)
if self.opt['grayscale']:
if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
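The two guards above only read the optional keys when they are actually present in the options dict. A slightly more compact equivalent, assuming self.opt behaves like a plain dict, is dict.get; a minimal sketch of the pattern (opt and its key are hypothetical here):

opt = {'doResizeLoss': True}

# What the commit adds (safe even when the key is absent):
do_resize = 'doResizeLoss' in opt.keys() and opt['doResizeLoss']

# Equivalent, slightly more idiomatic form:
do_resize_alt = opt.get('doResizeLoss', False)

assert do_resize == do_resize_alt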

View File

@ -0,0 +1,74 @@
import torch
import os
from PIL import Image
import numpy as np
import options.options as option
from data import create_dataloader, create_dataset
import math
from tqdm import tqdm
from torchvision import transforms
from utils.fdpl_util import dct_2d, extract_patches_2d
import random
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from utils.colors import rgb2ycbcr
import torch.nn.functional as F
input_config = "../../options/train_imgset_pixgan_srg4_fdpl.yml"
output_file = "fdpr_diff_means.pt"
device = 'cuda'
patch_size=128
if __name__ == '__main__':
opt = option.parse(input_config, is_train=True)
opt['dist'] = False
# Create a dataset to load from (this dataset loads HR/LR images and performs any distortions specified by the YML).
dataset_opt = opt['datasets']['train']
train_set = create_dataset(dataset_opt)
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
total_iters = int(opt['train']['niter'])
total_epochs = int(math.ceil(total_iters / train_size))
train_loader = create_dataloader(train_set, dataset_opt, opt, None)
print('Number of train images: {:,d}, iters: {:,d}'.format(
len(train_set), train_size))
# calculate the perceptual weights
master_diff = np.zeros((patch_size, patch_size))
num_patches = 0
all_diff_patches = []
tq = tqdm(train_loader)
sampled = 0
for train_data in tq:
if sampled > 200:
break
sampled += 1
im = rgb2ycbcr(train_data['GT'].double())
im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
size=im.shape[2:],
mode="bicubic"))
patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)
patches_hr = dct_2d(patches_hr, norm='ortho')
patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size,patch_size), batch_first=True)
patches_lr = dct_2d(patches_lr, norm='ortho')
b, p, c, w, h = patches_hr.shape
diffs = torch.abs(patches_lr - patches_hr) / ((torch.abs(patches_lr) + torch.abs(patches_hr)) / 2 + .00000001)
num_patches += b * p
all_diff_patches.append(torch.sum(diffs, dim=(0, 1)))
diff_patches = torch.stack(all_diff_patches, dim=0)
diff_means = torch.sum(diff_patches, dim=0) / num_patches
torch.save(diff_means, output_file)
print(diff_means)
for i in range(3):
fig, ax = plt.subplots()
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
im = ax.imshow(diff_means[i].numpy())
ax.set_title("mean_diff for channel %i" % (i,))
fig.colorbar(im, cax=cax, orientation='vertical')
plt.show()
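A quick sanity check of the saved output before pointing the training YAML's fdpl_loss.data_mean option at it (a sketch; the expected shape follows from the per-channel maps plotted above):

import torch

diff_means = torch.load("fdpr_diff_means.pt")
print(diff_means.shape)   # (3, patch_size, patch_size): one mean relative-diff map per Y/Cb/Cr channel
print(diff_means.mean())  # overall magnitude of the frequency loss between LR and HR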

View File

@ -6,7 +6,7 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from models.base_model import BaseModel
from models.loss import GANLoss
from models.loss import GANLoss, FDPLLoss
from apex import amp
from data.weight_scheduler import get_scheduler_for_opt
import torch.nn.functional as F
@ -62,6 +62,16 @@ class SRGANModel(BaseModel):
logger.info('Remove pixel loss.')
self.cri_pix = None
# FDPL loss.
if 'fdpl_loss' in train_opt.keys():
fdpl_opt = train_opt['fdpl_loss']
self.fdpl_weight = fdpl_opt['weight']
self.fdpl_enabled = self.fdpl_weight > 0
if self.fdpl_enabled:
self.cri_fdpl = FDPLLoss(fdpl_opt['data_mean'], self.device)
else:
self.fdpl_enabled = False
# G feature loss
if train_opt['feature_weight'] and train_opt['feature_weight'] > 0:
# For backwards compatibility, use a scheduler definition instead. Remove this at some point.
@ -305,10 +315,14 @@ class SRGANModel(BaseModel):
if using_gan_img:
l_g_pix_log = None
l_g_fea_log = None
l_g_fdpl = None
if self.cri_pix and not using_gan_img: # pixel loss
l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
l_g_pix_log = l_g_pix / self.l_pix_w
l_g_total += l_g_pix
if self.fdpl_enabled and not using_gan_img:
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
l_g_total += l_g_fdpl * self.fdpl_weight
if self.cri_fea and not using_gan_img: # feature loss
real_fea = self.netF(pix).detach()
fake_fea = self.netF(fea_GenOut)
@ -535,6 +549,8 @@ class SRGANModel(BaseModel):
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
if self.cri_pix and l_g_pix_log is not None:
self.add_log_entry('l_g_pix', l_g_pix_log.item())
if self.fdpl_enabled and l_g_fdpl is not None:
self.add_log_entry('l_g_fdpl', l_g_fdpl.item())
if self.cri_fea and l_g_fea_log is not None:
self.add_log_entry('feature_weight', fea_w)
self.add_log_entry('l_g_fea', l_g_fea_log.item())
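For reference, the new block above expects a sub-dictionary like the following in the parsed training options (a hypothetical sketch; the weight value is an assumption, and data_mean should point at the file written by data_scripts/compute_fdpl_perceptual_weights.py):

train_opt = {
    'fdpl_loss': {
        'weight': 1.0,                      # fdpl_weight; any value > 0 enables the loss
        'data_mean': 'fdpr_diff_means.pt',  # pre-computed diff means passed to FDPLLoss
    },
    # ... the remaining SRGAN training options are unchanged
}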

View File

@ -1,5 +1,8 @@
import torch
import torch.nn as nn
import numpy as np
from utils.fdpl_util import extract_patches_2d, dct_2d
from utils.colors import rgb2ycbcr
class CharbonnierLoss(nn.Module):
@ -75,3 +78,99 @@ class GradientPenaltyLoss(nn.Module):
loss = ((grad_interp_norm - 1)**2).mean()
return loss
# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL
# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py
# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs).
class FDPLLoss(nn.Module):
"""
Loss function taking the MSE between the 2D DCT coefficients
of predicted and target images or an image channel.
DCT coefficients are computed for each 8x8 block of the image.
Important note about this loss: since it operates in the frequency domain, precision is highly important.
It works on FP64 numbers and will fail if you attempt to use it with AMP; split this loss
off from the rest of the losses passed to amp.scale_loss().
"""
def __init__(self, dataset_diff_means_file, device):
"""
dataset_diff_means (torch.tensor): Pre-computed frequency-domain mean differences between LR and HR images.
"""
# These values are derived from the JPEG standard.
qt_Y = torch.tensor([[16, 11, 10, 16, 24, 40, 51, 61],
[12, 12, 14, 19, 26, 58, 60, 55],
[14, 13, 16, 24, 40, 57, 69, 56],
[14, 17, 22, 29, 51, 87, 80, 62],
[18, 22, 37, 56, 68, 109, 103, 77],
[24, 35, 55, 64, 81, 104, 113, 92],
[49, 64, 78, 87, 103, 121, 120, 101],
[72, 92, 95, 98, 112, 100, 103, 99]],
dtype=torch.double,
device=device,
requires_grad=False)
qt_C = torch.tensor([[17, 18, 24, 47, 99, 99, 99, 99],
[18, 21, 26, 66, 99, 99, 99, 99],
[24, 26, 56, 99, 99, 99, 99, 99],
[47, 66, 99, 99, 99, 99, 99, 99],
[99, 99, 99, 99, 99, 99, 99, 99],
[99, 99, 99, 99, 99, 99, 99, 99],
[99, 99, 99, 99, 99, 99, 99, 99],
[99, 99, 99, 99, 99, 99, 99, 99]],
dtype=torch.double,
device=device,
requires_grad=False)
"""
Reasoning behind this perceptual weight matrix: JPEG gives us a model of which frequencies are important
for human perception. In that model, lower frequencies are more important than higher ones. Because
of this, the higher frequencies are the first to go during compression. As compression increases, the effect
spreads to the lower frequencies, which degrades perceptual quality. But when the lower frequencies are
preserved, JPEG does an excellent job of compressing without a noticeable loss of quality.
In super resolution, we already have the low frequencies. In fact, that is really all we have in the low
resolution images.
As evidenced by the diff_means matrix, what is lost in the SR images is the mid-range frequencies:
those across and towards the centre of the diagonal. We can bias our model to recover these frequencies
by having our loss function prioritize these coefficients, with priority determined by the magnitude of
relative change between the low-res and high-res images. But we can take this further, into a true
perceptual loss, by additionally prioritizing DCT coefficients according to the importance assigned to them
by the JPEG quantization table. That is how the perceptual weights below are created.
The problem is that we don't know whether the JPEG model is optimal, so there is room for qualitative evaluation
of the quantization table values. We could refine these perceptual weights by deleting each in turn for a small
set of images and evaluating the resulting change in perceived quality. I can do this on my own to start and,
if it works, run a small user study to determine this.
"""
diff_means = torch.tensor(torch.load(dataset_diff_means_file), device=device)
perceptual_weights = torch.stack([(torch.ones_like(qt_Y, device=device) / qt_Y),
(torch.ones_like(qt_C, device=device) / qt_C),
(torch.ones_like(qt_C, device=device) / qt_C)])
perceptual_weights = perceptual_weights * diff_means
self.perceptual_weights = perceptual_weights / torch.mean(perceptual_weights)
super(FDPLLoss, self).__init__()
def forward(self, predictions, targets):
"""
Args:
predictions (torch.tensor): output of an image transformation model.
shape: batch_size x 3 x H x W
targets (torch.tensor): ground truth images corresponding to outputs
shape: batch_size x 3 x H x W
Returns:
loss (torch.Tensor): weighted squared error between the predicted and ground truth 2D DCT coefficients
"""
# transition to fp64 and then convert to YCC color space.
predictions = rgb2ycbcr(predictions.double())
targets = rgb2ycbcr(targets.double())
# get DCT coefficients of ground truth patches
patches = extract_patches_2d(img=targets, patch_shape=(8, 8), batch_first=True)
ground_truth_dct = dct_2d(patches, norm='ortho')
# get DCT coefficients of transformed images
patches = extract_patches_2d(img=predictions, patch_shape=(8, 8), batch_first=True)
outputs_dct = dct_2d(patches, norm='ortho')
loss = torch.sum(((outputs_dct - ground_truth_dct).pow(2)) * self.perceptual_weights)
return loss
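Following the docstring's note about FP64 and AMP, one way to keep this term out of the AMP-scaled backward pass is roughly the following (a sketch with hypothetical variable names; the diff above simply folds the weighted term into l_g_total):

from apex import amp

# AMP handles every other generator loss; the FP64 FDPL term is back-propagated separately.
with amp.scale_loss(l_g_other, optimizer_G) as scaled_loss:
    scaled_loss.backward(retain_graph=True)
l_g_fdpl = cri_fdpl(fake_img, real_img) * fdpl_weight
l_g_fdpl.backward()
optimizer_G.step()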

codes/utils/colors.py Normal file
View File

@ -0,0 +1,409 @@
# Differentiable color conversions from https://github.com/TheZino/pytorch-color-conversions
from warnings import warn
import numpy as np
import torch
from scipy import linalg
xyz_from_rgb = torch.Tensor([[0.412453, 0.357580, 0.180423],
[0.212671, 0.715160, 0.072169],
[0.019334, 0.119193, 0.950227]])
rgb_from_xyz = torch.Tensor(linalg.inv(xyz_from_rgb))
illuminants = \
{"A": {'2': torch.Tensor([(1.098466069456375, 1, 0.3558228003436005)]),
'10': torch.Tensor([(1.111420406956693, 1, 0.3519978321919493)])},
"D50": {'2': torch.Tensor([(0.9642119944211994, 1, 0.8251882845188288)]),
'10': torch.Tensor([(0.9672062750333777, 1, 0.8142801513128616)])},
"D55": {'2': torch.Tensor([(0.956797052643698, 1, 0.9214805860173273)]),
'10': torch.Tensor([(0.9579665682254781, 1, 0.9092525159847462)])},
"D65": {'2': torch.Tensor([(0.95047, 1., 1.08883)]), # This was: `lab_ref_white`
'10': torch.Tensor([(0.94809667673716, 1, 1.0730513595166162)])},
"D75": {'2': torch.Tensor([(0.9497220898840717, 1, 1.226393520724154)]),
'10': torch.Tensor([(0.9441713925645873, 1, 1.2064272211720228)])},
"E": {'2': torch.Tensor([(1.0, 1.0, 1.0)]),
'10': torch.Tensor([(1.0, 1.0, 1.0)])}}
# -------------------------------------------------------------
# The conversion functions that make use of the matrices above
# -------------------------------------------------------------
##### RGB - YCbCr
# Helper for the creation of module-global constant tensors
def _t(data):
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # TODO inherit this
device = torch.device("cpu") # TODO inherit this
return torch.tensor(data, requires_grad=False, dtype=torch.float32, device=device)
# Helper for color matrix multiplication
def _mul(coeffs, image):
# This implementation is clearly suboptimal. The function will
# be reimplemented with 'einsum' once a bug in pytorch 0.4.0 is
# fixed (einsum modifies variables in-place, #7763).
coeffs = coeffs.to(image.device)
r0 = image[:, 0:1, :, :].repeat(1, 3, 1, 1) * coeffs[:, 0].view(1, 3, 1, 1)
r1 = image[:, 1:2, :, :].repeat(1, 3, 1, 1) * coeffs[:, 1].view(1, 3, 1, 1)
r2 = image[:, 2:3, :, :].repeat(1, 3, 1, 1) * coeffs[:, 2].view(1, 3, 1, 1)
return r0 + r1 + r2
# return torch.einsum("dc,bcij->bdij", (coeffs.to(image.device), image))
_RGB_TO_YCBCR = _t([[0.257, 0.504, 0.098], [-0.148, -0.291, 0.439], [0.439 , -0.368, -0.071]])
_YCBCR_OFF = _t([0.063, 0.502, 0.502]).view(1, 3, 1, 1)
def rgb2ycbcr(rgb):
"""sRGB to YCbCr conversion."""
clip_rgb=False
if clip_rgb:
rgb = torch.clamp(rgb, 0, 1)
return _mul(_RGB_TO_YCBCR, rgb) + _YCBCR_OFF.to(rgb.device)
def ycbcr2rgb(rgb):
"""YCbCr to sRGB conversion."""
clip_rgb=False
rgb = _mul(torch.inverse(_RGB_TO_YCBCR), rgb - _YCBCR_OFF.to(rgb.device))
if clip_rgb:
rgb = torch.clamp(rgb, 0, 1)
return rgb
##### HSV - RGB
def rgb2hsv(rgb):
"""
R, G and B input range: [0, 1]
H, S and V output range: [0, 1]
"""
eps = 1e-7
var_R = rgb[:,0,:,:]
var_G = rgb[:,1,:,:]
var_B = rgb[:,2,:,:]
var_Min = rgb.min(1)[0] #Min. value of RGB
var_Max = rgb.max(1)[0] #Max. value of RGB
del_Max = var_Max - var_Min ##Delta RGB value
H = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device)
S = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device)
V = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device)
V = var_Max
#This is a gray, no chroma...
mask = del_Max == 0
H[mask] = 0
S[mask] = 0
#Chromatic data...
S = del_Max / (var_Max + eps)
del_R = ( ( ( var_Max - var_R ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps)
del_G = ( ( ( var_Max - var_G ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps)
del_B = ( ( ( var_Max - var_B ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps)
H = torch.where( var_R == var_Max , del_B - del_G, H)
H = torch.where( var_G == var_Max , ( 1 / 3 ) + del_R - del_B, H)
H = torch.where( var_B == var_Max ,( 2 / 3 ) + del_G - del_R, H)
# if ( H < 0 ) H += 1
# if ( H > 1 ) H -= 1
return torch.stack([H, S, V], 1)
def hsv2rgb(hsv):
"""
H, S and V input range: [0, 1]
R, G and B output range: [0, 1]
"""
eps = 1e-7
bb,cc,hh,ww = hsv.shape
H = hsv[:,0,:,:]
S = hsv[:,1,:,:]
V = hsv[:,2,:,:]
# var_h = torch.zeros(bb,hh,ww)
# var_s = torch.zeros(bb,hh,ww)
# var_v = torch.zeros(bb,hh,ww)
# var_r = torch.zeros(bb,hh,ww)
# var_g = torch.zeros(bb,hh,ww)
# var_b = torch.zeros(bb,hh,ww)
# Grayscale
if (S == 0).all():
R = V
G = V
B = V
# Chromatic data
else:
var_h = H * 6
var_h[var_h == 6] = 0 #H must be < 1
var_i = var_h.floor() #Or ... var_i = floor( var_h )
var_1 = V * ( 1 - S )
var_2 = V * ( 1 - S * ( var_h - var_i ) )
var_3 = V * ( 1 - S * ( 1 - ( var_h - var_i ) ) )
# else { var_r = V ; var_g = var_1 ; var_b = var_2 }
var_r = V
var_g = var_1
var_b = var_2
# var_i == 0 { var_r = V ; var_g = var_3 ; var_b = var_1 }
var_r = torch.where(var_i == 0, V, var_r)
var_g = torch.where(var_i == 0, var_3, var_g)
var_b = torch.where(var_i == 0, var_1, var_b)
# else if ( var_i == 1 ) { var_r = var_2 ; var_g = V ; var_b = var_1 }
var_r = torch.where(var_i == 1, var_2, var_r)
var_g = torch.where(var_i == 1, V, var_g)
var_b = torch.where(var_i == 1, var_1, var_b)
# else if ( var_i == 2 ) { var_r = var_1 ; var_g = V ; var_b = var_3 }
var_r = torch.where(var_i == 2, var_1, var_r)
var_g = torch.where(var_i == 2, V, var_g)
var_b = torch.where(var_i == 2, var_3, var_b)
# else if ( var_i == 3 ) { var_r = var_1 ; var_g = var_2 ; var_b = V }
var_r = torch.where(var_i == 3, var_1, var_r)
var_g = torch.where(var_i == 3, var_2, var_g)
var_b = torch.where(var_i == 3, V, var_b)
# else if ( var_i == 4 ) { var_r = var_3 ; var_g = var_1 ; var_b = V }
var_r = torch.where(var_i == 4, var_3, var_r)
var_g = torch.where(var_i == 4, var_1, var_g)
var_b = torch.where(var_i == 4, V, var_b)
R = var_r #* 255
G = var_g #* 255
B = var_b #* 255
return torch.stack([R, G, B], 1)
##### LAB - RGB
def _convert(matrix, arr):
"""Do the color space conversion.
Parameters
----------
matrix : array_like
The 3x3 matrix to use.
arr : array_like
The input array.
Returns
-------
out : ndarray, dtype=float
The converted array.
"""
if arr.is_cuda:
matrix = matrix.cuda()
bs, ch, h, w = arr.shape
arr = arr.permute((0,2,3,1))
arr = arr.contiguous().view(-1,1,3)
matrix = matrix.transpose(0,1).unsqueeze(0)
matrix = matrix.repeat(arr.shape[0],1,1)
res = torch.bmm(arr,matrix)
res = res.view(bs,h,w,ch)
res = res.transpose(3,2).transpose(2,1)
return res
def get_xyz_coords(illuminant, observer):
"""Get the XYZ coordinates of the given illuminant and observer [1]_.
Parameters
----------
illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
The name of the illuminant (the function is NOT case sensitive).
observer : {"2", "10"}, optional
The aperture angle of the observer.
Returns
-------
(x, y, z) : tuple
A tuple with 3 elements containing the XYZ coordinates of the given
illuminant.
Raises
------
ValueError
If either the illuminant or the observer angle are not supported or
unknown.
References
----------
.. [1] https://en.wikipedia.org/wiki/Standard_illuminant
"""
illuminant = illuminant.upper()
try:
return illuminants[illuminant][observer]
except KeyError:
raise ValueError("Unknown illuminant/observer combination\
(\'{0}\', \'{1}\')".format(illuminant, observer))
def rgb2xyz(rgb):
mask = rgb > 0.04045
rgbm = rgb.clone()
tmp = torch.pow((rgb + 0.055) / 1.055, 2.4)
rgb = torch.where(mask, tmp, rgb)
rgbm = rgb.clone()
rgb[~mask] = rgbm[~mask]/12.92
return _convert(xyz_from_rgb, rgb)
def xyz2lab(xyz, illuminant="D65", observer="2"):
# arr = _prepare_colorarray(xyz)
xyz_ref_white = get_xyz_coords(illuminant, observer)
#cuda
if xyz.is_cuda:
xyz_ref_white = xyz_ref_white.cuda()
# scale by CIE XYZ tristimulus values of the reference white point
xyz = xyz / xyz_ref_white.view(1,3,1,1)
# Nonlinear distortion and linear transformation
mask = xyz > 0.008856
xyzm = xyz.clone()
xyz[mask] = torch.pow(xyzm[mask], 1/3)
xyzm = xyz.clone()
xyz[~mask] = 7.787 * xyzm[~mask] + 16. / 116.
x, y, z = xyz[:, 0, :, :], xyz[:, 1, :, :], xyz[:, 2, :, :]
# Vector scaling
L = (116. * y) - 16.
a = 500.0 * (x - y)
b = 200.0 * (y - z)
return torch.stack((L,a,b), 1)
def rgb2lab(rgb, illuminant="D65", observer="2"):
"""RGB to lab color space conversion.
Parameters
----------
rgb : array_like
The image in RGB format, in a 3- or 4-D array of shape
``(.., ..,[ ..,] 3)``.
illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
The name of the illuminant (the function is NOT case sensitive).
observer : {"2", "10"}, optional
The aperture angle of the observer.
Returns
-------
out : ndarray
The image in Lab format, in a 3- or 4-D array of shape
``(.., ..,[ ..,] 3)``.
Raises
------
ValueError
If `rgb` is not a 3- or 4-D array of shape ``(.., ..,[ ..,] 3)``.
References
----------
.. [1] https://en.wikipedia.org/wiki/Standard_illuminant
Notes
-----
This function uses rgb2xyz and xyz2lab.
By default Observer= 2A, Illuminant= D65. CIE XYZ tristimulus values
x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for
a list of supported illuminants.
"""
return xyz2lab(rgb2xyz(rgb), illuminant, observer)
def lab2xyz(lab, illuminant="D65", observer="2"):
arr = lab.clone()
L, a, b = arr[:, 0, :, :], arr[:, 1, :, :], arr[:, 2, :, :]
y = (L + 16.) / 116.
x = (a / 500.) + y
z = y - (b / 200.)
# if (z < 0).sum() > 0:
# warn('Color data out of range: Z < 0 in %s pixels' % (z < 0).sum().item())
# z[z < 0] = 0 # NO GRADIENT!!!!
out = torch.stack((x, y, z),1)
mask = out > 0.2068966
outm = out.clone()
out[mask] = torch.pow(outm[mask], 3.)
outm = out.clone()
out[~mask] = (outm[~mask] - 16.0 / 116.) / 7.787
# rescale to the reference white (illuminant)
xyz_ref_white = get_xyz_coords(illuminant, observer)
# cuda
if lab.is_cuda:
xyz_ref_white = xyz_ref_white.cuda()
xyz_ref_white = xyz_ref_white.unsqueeze(2).unsqueeze(2).repeat(1,1,out.shape[2],out.shape[3])
out = out * xyz_ref_white
return out
def xyz2rgb(xyz):
arr = _convert(rgb_from_xyz, xyz)
mask = arr > 0.0031308
arrm = arr.clone()
arr[mask] = 1.055 * torch.pow(arrm[mask], 1 / 2.4) - 0.055
arrm = arr.clone()
arr[~mask] = arrm[~mask] * 12.92
# CLAMP KILLS GRADIENTS
# mask_z = arr < 0
# arr[mask_z] = 0
# mask_o = arr > 1
# arr[mask_o] = 1
# torch.clamp(arr, 0, 1, out=arr)
return arr
def lab2rgb(lab, illuminant="D65", observer="2"):
"""Lab to RGB color space conversion.
Parameters
----------
lab : array_like
The image in Lab format, in a 3-D array of shape ``(.., .., 3)``.
illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
The name of the illuminant (the function is NOT case sensitive).
observer : {"2", "10"}, optional
The aperture angle of the observer.
Returns
-------
out : ndarray
The image in RGB format, in a 3-D array of shape ``(.., .., 3)``.
Raises
------
ValueError
If `lab` is not a 3-D array of shape ``(.., .., 3)``.
References
----------
.. [1] https://en.wikipedia.org/wiki/Standard_illuminant
Notes
-----
This function uses lab2xyz and xyz2rgb.
By default Observer= 2A, Illuminant= D65. CIE XYZ tristimulus values
x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for
a list of supported illuminants.
"""
return xyz2rgb(lab2xyz(lab, illuminant, observer))
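A quick round-trip check of the YCbCr conversion these loss paths rely on (a sketch; tensor sizes are arbitrary):

import torch
from utils.colors import rgb2ycbcr, ycbcr2rgb

rgb = torch.rand(2, 3, 32, 32)        # N x C x H x W, values in [0, 1]
ycc = rgb2ycbcr(rgb)
rgb_back = ycbcr2rgb(ycc)
print((rgb - rgb_back).abs().max())   # should be near zero (float32 round-off only)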

codes/utils/fdpl_util.py Normal file
View File

@ -0,0 +1,136 @@
import numpy as np
import torch
import torch.nn as nn
# note: all dct related functions are either exactly as or based on those
# at https://github.com/zh217/torch-dct
def dct(x, norm=None):
"""
Discrete Cosine Transform, Type II (a.k.a. the DCT)
For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
:param x: the input signal
:param norm: the normalization, None or 'ortho'
:return: the DCT-II of the signal over the last dimension
"""
x_shape = x.shape
N = x_shape[-1]
x = x.contiguous().view(-1, N)
v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
Vc = torch.rfft(v, 1, onesided=False)
k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)
V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
if norm == 'ortho':
V[:, 0] /= np.sqrt(N) * 2
V[:, 1:] /= np.sqrt(N / 2) * 2
V = 2 * V.view(*x_shape)
return V
def idct(X, norm=None):
"""
The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
Our definition of idct is that idct(dct(x)) == x
For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
:param X: the input signal
:param norm: the normalization, None or 'ortho'
:return: the inverse DCT-II of the signal over the last dimension
"""
x_shape = X.shape
N = x_shape[-1]
X_v = X.contiguous().view(-1, x_shape[-1]) / 2
if norm == 'ortho':
X_v[:, 0] *= np.sqrt(N) * 2
X_v[:, 1:] *= np.sqrt(N / 2) * 2
k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)
V_t_r = X_v
V_t_r = V_t_r.to(X.device)
V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
V_t_i = V_t_i.to(X.device)
V_r = V_t_r * W_r - V_t_i * W_i
V_i = V_t_r * W_i + V_t_i * W_r
V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
v = torch.irfft(V, 1, onesided=False)
x = v.new_zeros(v.shape)
x[:, ::2] += v[:, :N - (N // 2)]
x[:, 1::2] += v.flip([1])[:, :N // 2]
return x.view(*x_shape)
def dct_2d(x, norm=None):
"""
2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
:param x: the input signal
:param norm: the normalization, None or 'ortho'
:return: the DCT-II of the signal over the last 2 dimensions
"""
X1 = dct(x, norm=norm)
X2 = dct(X1.transpose(-1, -2), norm=norm)
return X2.transpose(-1, -2)
def idct_2d(X, norm=None):
"""
The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
Our definition of idct is that idct_2d(dct_2d(x)) == x
For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
:param X: the input signal
:param norm: the normalization, None or 'ortho'
:return: the DCT-II of the signal over the last 2 dimensions
"""
x1 = idct(X, norm=norm)
x2 = idct(x1.transpose(-1, -2), norm=norm)
return x2.transpose(-1, -2)
def extract_patches_2d(img,patch_shape,step=[1.0,1.0],batch_first=False):
"""
source: https://gist.github.com/dem123456789/23f18fd78ac8da9615c347905e64fc78
"""
patch_H, patch_W = patch_shape[0], patch_shape[1]
if(img.size(2) < patch_H):
num_padded_H_Top = (patch_H - img.size(2))//2
num_padded_H_Bottom = patch_H - img.size(2) - num_padded_H_Top
padding_H = nn.ConstantPad2d((0, 0, num_padded_H_Top, num_padded_H_Bottom), 0)
img = padding_H(img)
if(img.size(3) < patch_W):
num_padded_W_Left = (patch_W - img.size(3))//2
num_padded_W_Right = patch_W - img.size(3) - num_padded_W_Left
padding_W = nn.ConstantPad2d((num_padded_W_Left,num_padded_W_Right, 0, 0), 0)
img = padding_W(img)
step_int = [0, 0]
step_int[0] = int(patch_H*step[0]) if(isinstance(step[0], float)) else step[0]
step_int[1] = int(patch_W*step[1]) if(isinstance(step[1], float)) else step[1]
patches_fold_H = img.unfold(2, patch_H, step_int[0])
if((img.size(2) - patch_H) % step_int[0] != 0):
patches_fold_H = torch.cat((patches_fold_H,
img[:, :, -patch_H:, :].permute(0,1,3,2).unsqueeze(2)),dim=2)
patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_int[1])
if((img.size(3) - patch_W) % step_int[1] != 0):
patches_fold_HW = torch.cat((patches_fold_HW,
patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)), dim=3)
patches = patches_fold_HW.permute(2, 3, 0, 1, 4, 5)
patches = patches.reshape(-1, img.size(0), img.size(1), patch_H, patch_W)
if(batch_first):
patches = patches.permute(1, 0, 2, 3, 4)
return patches
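A small consistency check of the utilities above (a sketch; it assumes a torch version that still ships torch.rfft/irfft, which this module relies on):

import torch
from utils.fdpl_util import dct_2d, idct_2d, extract_patches_2d

x = torch.rand(2, 3, 16, 16, dtype=torch.double)
X = dct_2d(x, norm='ortho')
print(torch.allclose(idct_2d(X, norm='ortho'), x, atol=1e-10))  # idct_2d(dct_2d(x)) == x

patches = extract_patches_2d(img=x, patch_shape=(8, 8), batch_first=True)
print(patches.shape)  # torch.Size([2, 4, 3, 8, 8]): four non-overlapping 8x8 patches per image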