forked from mrq/DL-Art-School
Add FDPL Loss
New loss type that can replace the PSNR loss. It operates in the frequency domain and focuses on the frequency features that are lost during the HR->LR conversion.
This commit is contained in:
parent
85ee64b8d9
commit
7629cb0e61
|
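For orientation, here is a minimal sketch of the computation this commit adds (the real implementation is the FDPLLoss class in codes/models/loss.py further down; fdpl_sketch and perceptual_weights are illustrative names, not part of the commit):

import torch
from utils.colors import rgb2ycbcr
from utils.fdpl_util import dct_2d, extract_patches_2d

def fdpl_sketch(pred, target, perceptual_weights):
    # fp64 + YCbCr, then 8x8 patches and a 2D DCT per patch, mirroring FDPLLoss.forward below
    p = dct_2d(extract_patches_2d(rgb2ycbcr(pred.double()), (8, 8), batch_first=True), norm='ortho')
    t = dct_2d(extract_patches_2d(rgb2ycbcr(target.double()), (8, 8), batch_first=True), norm='ortho')
    # squared DCT-coefficient error, weighted by the precomputed perceptual weights
    return torch.sum((p - t).pow(2) * perceptual_weights)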
@@ -172,7 +172,7 @@ class LQGTDataset(data.Dataset):
|
|||
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if self.opt['doResizeLoss']:
|
||||
if 'doResizeLoss' in self.opt.keys() and self.opt['doResizeLoss']:
|
||||
r = random.randrange(0, 10)
|
||||
if r > 5:
|
||||
img_LQ = cv2.resize(img_LQ, (int(LQ_size/2), int(LQ_size/2)), interpolation=cv2.INTER_LINEAR)
|
||||
|
@@ -215,7 +215,7 @@ class LQGTDataset(data.Dataset):
|
|||
corruption_buffer.seek(0)
|
||||
img_LQ = Image.open(corruption_buffer)
|
||||
|
||||
if self.opt['grayscale']:
|
||||
if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
|
||||
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
|
||||
|
||||
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||
|
|
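The two guards added above read optional keys from the dataset section of the YML; a hedged illustration of the corresponding entries (key names taken from this hunk, values illustrative):

dataset_opt_fragment = {
    'doResizeLoss': True,   # if set, img_LQ is randomly downscaled to half of LQ_size for some samples
    'grayscale': False,     # if set, img_LQ is converted to grayscale and kept as 3-channel RGB
}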
74
codes/data_scripts/compute_fdpl_perceptual_weights.py
Normal file
|
@@ -0,0 +1,74 @@
|
|||
import torch
|
||||
import os
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import options.options as option
|
||||
from data import create_dataloader, create_dataset
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from torchvision import transforms
|
||||
from utils.fdpl_util import dct_2d, extract_patches_2d
|
||||
import random
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
from utils.colors import rgb2ycbcr
|
||||
import torch.nn.functional as F
|
||||
|
||||
input_config = "../../options/train_imgset_pixgan_srg4_fdpl.yml"
|
||||
output_file = "fdpr_diff_means.pt"
|
||||
device = 'cuda'
|
||||
patch_size=128
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = option.parse(input_config, is_train=True)
|
||||
opt['dist'] = False
|
||||
|
||||
# Create a dataset to load from (this dataset loads HR/LR images and performs any distortions specified by the YML).
|
||||
dataset_opt = opt['datasets']['train']
|
||||
train_set = create_dataset(dataset_opt)
|
||||
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
|
||||
total_iters = int(opt['train']['niter'])
|
||||
total_epochs = int(math.ceil(total_iters / train_size))
|
||||
train_loader = create_dataloader(train_set, dataset_opt, opt, None)
|
||||
print('Number of train images: {:,d}, iters: {:,d}'.format(
|
||||
len(train_set), train_size))
|
||||
|
||||
# calculate the perceptual weights
|
||||
master_diff = np.zeros((patch_size, patch_size))
|
||||
num_patches = 0
|
||||
all_diff_patches = []
|
||||
tq = tqdm(train_loader)
|
||||
sampled = 0
|
||||
for train_data in tq:
|
||||
if sampled > 200:
|
||||
break
|
||||
sampled += 1
|
||||
|
||||
im = rgb2ycbcr(train_data['GT'].double())
|
||||
im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
|
||||
size=im.shape[2:],
|
||||
mode="bicubic"))
|
||||
patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)
|
||||
patches_hr = dct_2d(patches_hr, norm='ortho')
|
||||
patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size,patch_size), batch_first=True)
|
||||
patches_lr = dct_2d(patches_lr, norm='ortho')
|
||||
b, p, c, w, h = patches_hr.shape
|
||||
diffs = torch.abs(patches_lr - patches_hr) / ((torch.abs(patches_lr) + torch.abs(patches_hr)) / 2 + .00000001)
|
||||
num_patches += b * p
|
||||
all_diff_patches.append(torch.sum(diffs, dim=(0, 1)))
|
||||
|
||||
diff_patches = torch.stack(all_diff_patches, dim=0)
|
||||
diff_means = torch.sum(diff_patches, dim=0) / num_patches
|
||||
|
||||
torch.save(diff_means, output_file)
|
||||
print(diff_means)
|
||||
|
||||
for i in range(3):
|
||||
fig, ax = plt.subplots()
|
||||
divider = make_axes_locatable(ax)
|
||||
cax = divider.append_axes('right', size='5%', pad=0.05)
|
||||
im = ax.imshow(diff_means[i].numpy())
|
||||
ax.set_title("mean_diff for channel %i" % (i,))
|
||||
fig.colorbar(im, cax=cax, orientation='vertical')
|
||||
plt.show()
|
||||
|
|
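The tensor this script saves is what FDPLLoss consumes via the fdpl_loss 'data_mean' option in the SRGANModel diff below. A quick hedged check of the output, using the default output_file path set above:

import torch
diff_means = torch.load("fdpr_diff_means.pt")
print(diff_means.shape)  # (3, patch_size, patch_size): one mean-difference map per Y/Cb/Cr channel

Note that FDPLLoss multiplies these means against 8x8 JPEG quantization tables, so the file it loads must hold 3x8x8 means; with the default patch_size=128 above the shapes would not broadcast, so it is worth verifying that the training config points at an 8x8 diff-means file (or re-running this script with patch_size=8).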
@@ -6,7 +6,7 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
|
|||
import models.networks as networks
|
||||
import models.lr_scheduler as lr_scheduler
|
||||
from models.base_model import BaseModel
|
||||
from models.loss import GANLoss
|
||||
from models.loss import GANLoss, FDPLLoss
|
||||
from apex import amp
|
||||
from data.weight_scheduler import get_scheduler_for_opt
|
||||
import torch.nn.functional as F
|
||||
|
@@ -62,6 +62,16 @@ class SRGANModel(BaseModel):
|
|||
logger.info('Remove pixel loss.')
|
||||
self.cri_pix = None
|
||||
|
||||
# FDPL loss.
|
||||
if 'fdpl_loss' in train_opt.keys():
|
||||
fdpl_opt = train_opt['fdpl_loss']
|
||||
self.fdpl_weight = fdpl_opt['weight']
|
||||
self.fdpl_enabled = self.fdpl_weight > 0
|
||||
if self.fdpl_enabled:
|
||||
self.cri_fdpl = FDPLLoss(fdpl_opt['data_mean'], self.device)
|
||||
else:
|
||||
self.fdpl_enabled = False
|
||||
|
||||
# G feature loss
|
||||
if train_opt['feature_weight'] and train_opt['feature_weight'] > 0:
|
||||
# For backwards compatibility, use a scheduler definition instead. Remove this at some point.
|
||||
|
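For reference, a hedged sketch of the train_opt entries this block reads, written as the Python dict the YML would parse into (the exact layout of train_imgset_pixgan_srg4_fdpl.yml is not shown in this diff, so the values are illustrative):

train_opt_fragment = {
    'fdpl_loss': {
        'weight': 0.1,                       # any value > 0 enables the loss
        'data_mean': 'fdpr_diff_means.pt',   # produced by compute_fdpl_perceptual_weights.py
    },
}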
@@ -305,10 +315,14 @@ class SRGANModel(BaseModel):
|
|||
if using_gan_img:
|
||||
l_g_pix_log = None
|
||||
l_g_fea_log = None
|
||||
l_g_fdpl = None
|
||||
if self.cri_pix and not using_gan_img: # pixel loss
|
||||
l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
|
||||
l_g_pix_log = l_g_pix / self.l_pix_w
|
||||
l_g_total += l_g_pix
|
||||
if self.fdpl_enabled and not using_gan_img:
|
||||
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
|
||||
l_g_total += l_g_fdpl * self.fdpl_weight
|
||||
if self.cri_fea and not using_gan_img: # feature loss
|
||||
real_fea = self.netF(pix).detach()
|
||||
fake_fea = self.netF(fea_GenOut)
|
||||
|
@@ -535,6 +549,8 @@ class SRGANModel(BaseModel):
|
|||
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
|
||||
if self.cri_pix and l_g_pix_log is not None:
|
||||
self.add_log_entry('l_g_pix', l_g_pix_log.item())
|
||||
if self.fdpl_enabled and l_g_fdpl is not None:
|
||||
self.add_log_entry('l_g_fdpl', l_g_fdpl.item())
|
||||
if self.cri_fea and l_g_fea_log is not None:
|
||||
self.add_log_entry('feature_weight', fea_w)
|
||||
self.add_log_entry('l_g_fea', l_g_fea_log.item())
|
||||
|
|
|
@@ -1,5 +1,8 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from utils.fdpl_util import extract_patches_2d, dct_2d
|
||||
from utils.colors import rgb2ycbcr
|
||||
|
||||
|
||||
class CharbonnierLoss(nn.Module):
|
||||
|
@@ -75,3 +78,99 @@ class GradientPenaltyLoss(nn.Module):
|
|||
|
||||
loss = ((grad_interp_norm - 1)**2).mean()
|
||||
return loss
|
||||
|
||||
|
||||
# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL
|
||||
# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py
|
||||
# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs).
|
||||
class FDPLLoss(nn.Module):
|
||||
"""
|
||||
Loss function taking the MSE between the 2D DCT coefficients
|
||||
of predicted and target images or an image channel.
|
||||
DCT coefficients are computed for each 8x8 block of the image.
|
||||
|
||||
Important note about this loss: Since it operates in the frequency domain, precision is highly important.
|
||||
It works on FP64 numbers and will fail if you attempt to use it with AMP. Recommend you split this loss
|
||||
off from the rest of amp.scale_loss().
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_diff_means_file, device):
|
||||
"""
|
||||
dataset_diff_means_file (str): path to the pre-computed frequency-domain mean differences between LR and HR images.
|
||||
"""
|
||||
# These values are derived from the JPEG standard.
|
||||
qt_Y = torch.tensor([[16, 11, 10, 16, 24, 40, 51, 61],
|
||||
[12, 12, 14, 19, 26, 58, 60, 55],
|
||||
[14, 13, 16, 24, 40, 57, 69, 56],
|
||||
[14, 17, 22, 29, 51, 87, 80, 62],
|
||||
[18, 22, 37, 56, 68, 109, 103, 77],
|
||||
[24, 35, 55, 64, 81, 104, 113, 92],
|
||||
[49, 64, 78, 87, 103, 121, 120, 101],
|
||||
[72, 92, 95, 98, 112, 100, 103, 99]],
|
||||
dtype=torch.double,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
qt_C = torch.tensor([[17, 18, 24, 47, 99, 99, 99, 99],
|
||||
[18, 21, 26, 66, 99, 99, 99, 99],
|
||||
[24, 26, 56, 99, 99, 99, 99, 99],
|
||||
[47, 66, 99, 99, 99, 99, 99, 99],
|
||||
[99, 99, 99, 99, 99, 99, 99, 99],
|
||||
[99, 99, 99, 99, 99, 99, 99, 99],
|
||||
[99, 99, 99, 99, 99, 99, 99, 99],
|
||||
[99, 99, 99, 99, 99, 99, 99, 99]],
|
||||
dtype=torch.double,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
"""
|
||||
Reasoning behind this perceptual weight matrix: JPEG gives us a model of the frequencies that are important
|
||||
for human perception. In that model, lower frequencies are more important than higher frequencies. Because
|
||||
of this, the higher frequencies are the first to go during compression. As compression increases, the effect
|
||||
spreads to the lower frequencies, which degrades perceptual quality. But when the lower frequencies are
|
||||
preserved, JPEG does an excellent job of compression without a noticeable loss of quality.
|
||||
In super resolution, we already have the low frequencies. In fact that is really all we have in the low
|
||||
resolution images.
|
||||
As evidenced by the diff_means matrix above, what is lost in the SR images is the mid-range frequencies;
|
||||
those across and towards the centre of the diagonal. We can bias our model to recover these frequencies
|
||||
by having our loss function prioritize these coefficients, with priority determined by the magnitude of
|
||||
relative change between the low-res and high-res images. But we can take this further and into a true
|
||||
perceptual loss by further prioritizing DCT coefficients by the importance that has been assigned to them
|
||||
by the JPEG quantization table. That is how the table below is created.
|
||||
|
||||
The problem is that we don't know if the JPEG model is optimal. So there is room for qualitative evaluation
|
||||
of the quantization table values. We can further refine our perceptual weights by deleting each one in turn for a small
|
||||
set of images and evaluating the resulting change in perceived quality. I can do this on my own to start and
|
||||
if it works, I can do a small user study to determine this.
|
||||
"""
|
||||
diff_means = torch.tensor(torch.load(dataset_diff_means_file), device=device)
|
||||
perceptual_weights = torch.stack([(torch.ones_like(qt_Y, device=device) / qt_Y),
|
||||
(torch.ones_like(qt_C, device=device) / qt_C),
|
||||
(torch.ones_like(qt_C, device=device) / qt_C)])
|
||||
perceptual_weights = perceptual_weights * diff_means
|
||||
self.perceptual_weights = perceptual_weights / torch.mean(perceptual_weights)
|
||||
super(FDPLLoss, self).__init__()
|
||||
|
||||
def forward(self, predictions, targets):
|
||||
"""
|
||||
Args:
|
||||
predictions (torch.tensor): output of an image transformation model.
|
||||
shape: batch_size x 3 x H x W
|
||||
targets (torch.tensor): ground truth images corresponding to outputs
|
||||
shape: batch_size x 3 x H x W
|
||||
criterion (torch.nn.MSELoss): object used to calculate MSE
|
||||
|
||||
Returns:
|
||||
loss (float): MSE between predicted and ground truth 2D DCT coefficients
|
||||
"""
|
||||
# transition to fp64 and then convert to YCC color space.
|
||||
predictions = rgb2ycbcr(predictions.double())
|
||||
targets = rgb2ycbcr(targets.double())
|
||||
|
||||
# get DCT coefficients of ground truth patches
|
||||
patches = extract_patches_2d(img=targets, patch_shape=(8, 8), batch_first=True)
|
||||
ground_truth_dct = dct_2d(patches, norm='ortho')
|
||||
|
||||
# get DCT coefficients of transformed images
|
||||
patches = extract_patches_2d(img=predictions, patch_shape=(8, 8), batch_first=True)
|
||||
outputs_dct = dct_2d(patches, norm='ortho')
|
||||
loss = torch.sum(((outputs_dct - ground_truth_dct).pow(2)) * self.perceptual_weights)
|
||||
return loss
|
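A minimal usage sketch of the new class, mirroring how SRGANModel wires it up above; the file path, tensors, and weight are placeholders, and the diff-means file is assumed to hold 3x8x8 means so that it broadcasts against the JPEG tables:

import torch
from models.loss import FDPLLoss

device = torch.device('cuda')
cri_fdpl = FDPLLoss('fdpr_diff_means.pt', device)          # placeholder path
fake = torch.rand(4, 3, 64, 64, device=device, requires_grad=True)
real = torch.rand(4, 3, 64, 64, device=device)
# per the docstring, keep this term outside amp.scale_loss(); forward() promotes both inputs to fp64 itself
l_fdpl = cri_fdpl(fake, real) * 0.1                        # 0.1 = illustrative fdpl weight
l_fdpl.backward()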
409
codes/utils/colors.py
Normal file
|
@@ -0,0 +1,409 @@
|
|||
# Differentiable color conversions from https://github.com/TheZino/pytorch-color-conversions
|
||||
|
||||
from warnings import warn
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy import linalg
|
||||
|
||||
xyz_from_rgb = torch.Tensor([[0.412453, 0.357580, 0.180423],
|
||||
[0.212671, 0.715160, 0.072169],
|
||||
[0.019334, 0.119193, 0.950227]])
|
||||
|
||||
rgb_from_xyz = torch.Tensor(linalg.inv(xyz_from_rgb))
|
||||
|
||||
illuminants = \
|
||||
{"A": {'2': torch.Tensor([(1.098466069456375, 1, 0.3558228003436005)]),
|
||||
'10': torch.Tensor([(1.111420406956693, 1, 0.3519978321919493)])},
|
||||
"D50": {'2': torch.Tensor([(0.9642119944211994, 1, 0.8251882845188288)]),
|
||||
'10': torch.Tensor([(0.9672062750333777, 1, 0.8142801513128616)])},
|
||||
"D55": {'2': torch.Tensor([(0.956797052643698, 1, 0.9214805860173273)]),
|
||||
'10': torch.Tensor([(0.9579665682254781, 1, 0.9092525159847462)])},
|
||||
"D65": {'2': torch.Tensor([(0.95047, 1., 1.08883)]), # This was: `lab_ref_white`
|
||||
'10': torch.Tensor([(0.94809667673716, 1, 1.0730513595166162)])},
|
||||
"D75": {'2': torch.Tensor([(0.9497220898840717, 1, 1.226393520724154)]),
|
||||
'10': torch.Tensor([(0.9441713925645873, 1, 1.2064272211720228)])},
|
||||
"E": {'2': torch.Tensor([(1.0, 1.0, 1.0)]),
|
||||
'10': torch.Tensor([(1.0, 1.0, 1.0)])}}
|
||||
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# The conversion functions that make use of the matrices above
|
||||
# -------------------------------------------------------------
|
||||
|
||||
|
||||
##### RGB - YCbCr
|
||||
|
||||
# Helper for the creation of module-global constant tensors
|
||||
def _t(data):
|
||||
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # TODO inherit this
|
||||
device = torch.device("cpu") # TODO inherit this
|
||||
return torch.tensor(data, requires_grad=False, dtype=torch.float32, device=device)
|
||||
|
||||
# Helper for color matrix multiplication
|
||||
def _mul(coeffs, image):
|
||||
# This implementation is clearly suboptimal. The function will
|
||||
# be reimplemented with 'einsum' once a bug in pytorch 0.4.0 is
|
||||
# fixed (Einsum modifies variables in-place #7763).
|
||||
coeffs = coeffs.to(image.device)
|
||||
r0 = image[:, 0:1, :, :].repeat(1, 3, 1, 1) * coeffs[:, 0].view(1, 3, 1, 1)
|
||||
r1 = image[:, 1:2, :, :].repeat(1, 3, 1, 1) * coeffs[:, 1].view(1, 3, 1, 1)
|
||||
r2 = image[:, 2:3, :, :].repeat(1, 3, 1, 1) * coeffs[:, 2].view(1, 3, 1, 1)
|
||||
return r0 + r1 + r2
|
||||
# return torch.einsum("dc,bcij->bdij", (coeffs.to(image.device), image))
|
||||
|
||||
_RGB_TO_YCBCR = _t([[0.257, 0.504, 0.098], [-0.148, -0.291, 0.439], [0.439 , -0.368, -0.071]])
|
||||
_YCBCR_OFF = _t([0.063, 0.502, 0.502]).view(1, 3, 1, 1)
|
||||
|
||||
|
||||
def rgb2ycbcr(rgb):
|
||||
"""sRGB to YCbCr conversion."""
|
||||
clip_rgb=False
|
||||
if clip_rgb:
|
||||
rgb = torch.clamp(rgb, 0, 1)
|
||||
return _mul(_RGB_TO_YCBCR, rgb) + _YCBCR_OFF.to(rgb.device)
|
||||
|
||||
|
||||
def ycbcr2rgb(rgb):
|
||||
"""YCbCr to sRGB conversion."""
|
||||
clip_rgb=False
|
||||
rgb = _mul(torch.inverse(_RGB_TO_YCBCR), rgb - _YCBCR_OFF.to(rgb.device))
|
||||
if clip_rgb:
|
||||
rgb = torch.clamp(rgb, 0, 1)
|
||||
return rgb
|
||||
|
||||
|
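A quick hedged round-trip check for the pair above (random CPU batch; since the conversion matrices are inverses of each other, the error should be at fp32 round-off level):

import torch
from utils.colors import rgb2ycbcr, ycbcr2rgb

x = torch.rand(2, 3, 16, 16)                     # N x C x H x W in [0, 1]
err = (ycbcr2rgb(rgb2ycbcr(x)) - x).abs().max()
print(err)                                       # expected: roughly 1e-6 or smaller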
||||
##### HSV - RGB
|
||||
|
||||
def rgb2hsv(rgb):
|
||||
"""
|
||||
R, G and B input range = 0 ÷ 1.0
|
||||
H, S and V output range = 0 ÷ 1.0
|
||||
"""
|
||||
eps = 1e-7
|
||||
|
||||
var_R = rgb[:,0,:,:]
|
||||
var_G = rgb[:,1,:,:]
|
||||
var_B = rgb[:,2,:,:]
|
||||
|
||||
var_Min = rgb.min(1)[0] #Min. value of RGB
|
||||
var_Max = rgb.max(1)[0] #Max. value of RGB
|
||||
del_Max = var_Max - var_Min ##Delta RGB value
|
||||
|
||||
H = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device)
|
||||
S = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device)
|
||||
V = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device)
|
||||
|
||||
V = var_Max
|
||||
|
||||
#This is a gray, no chroma...
|
||||
mask = del_Max == 0
|
||||
H[mask] = 0
|
||||
S[mask] = 0
|
||||
|
||||
#Chromatic data...
|
||||
S = del_Max / (var_Max + eps)
|
||||
|
||||
del_R = ( ( ( var_Max - var_R ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps)
|
||||
del_G = ( ( ( var_Max - var_G ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps)
|
||||
del_B = ( ( ( var_Max - var_B ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps)
|
||||
|
||||
H = torch.where( var_R == var_Max , del_B - del_G, H)
|
||||
H = torch.where( var_G == var_Max , ( 1 / 3 ) + del_R - del_B, H)
|
||||
H = torch.where( var_B == var_Max ,( 2 / 3 ) + del_G - del_R, H)
|
||||
|
||||
# if ( H < 0 ) H += 1
|
||||
# if ( H > 1 ) H -= 1
|
||||
|
||||
return torch.stack([H, S, V], 1)
|
||||
|
||||
def hsv2rgb(hsv):
|
||||
"""
|
||||
H, S and V input range = 0 ÷ 1.0
|
||||
R, G and B output range = 0 ÷ 1.0
|
||||
"""
|
||||
|
||||
eps = 1e-7
|
||||
|
||||
bb,cc,hh,ww = hsv.shape
|
||||
H = hsv[:,0,:,:]
|
||||
S = hsv[:,1,:,:]
|
||||
V = hsv[:,2,:,:]
|
||||
|
||||
# var_h = torch.zeros(bb,hh,ww)
|
||||
# var_s = torch.zeros(bb,hh,ww)
|
||||
# var_v = torch.zeros(bb,hh,ww)
|
||||
|
||||
# var_r = torch.zeros(bb,hh,ww)
|
||||
# var_g = torch.zeros(bb,hh,ww)
|
||||
# var_b = torch.zeros(bb,hh,ww)
|
||||
|
||||
# Grayscale
|
||||
if (S == 0).all():
|
||||
|
||||
R = V
|
||||
G = V
|
||||
B = V
|
||||
|
||||
# Chromatic data
|
||||
else:
|
||||
|
||||
var_h = H * 6
|
||||
|
||||
var_h[var_h == 6] = 0 #H must be < 1
|
||||
var_i = var_h.floor() #Or ... var_i = floor( var_h )
|
||||
var_1 = V * ( 1 - S )
|
||||
var_2 = V * ( 1 - S * ( var_h - var_i ) )
|
||||
var_3 = V * ( 1 - S * ( 1 - ( var_h - var_i ) ) )
|
||||
|
||||
# else { var_r = V ; var_g = var_1 ; var_b = var_2 }
|
||||
var_r = V
|
||||
var_g = var_1
|
||||
var_b = var_2
|
||||
|
||||
# var_i == 0 { var_r = V ; var_g = var_3 ; var_b = var_1 }
|
||||
var_r = torch.where(var_i == 0, V, var_r)
|
||||
var_g = torch.where(var_i == 0, var_3, var_g)
|
||||
var_b = torch.where(var_i == 0, var_1, var_b)
|
||||
|
||||
# else if ( var_i == 1 ) { var_r = var_2 ; var_g = V ; var_b = var_1 }
|
||||
var_r = torch.where(var_i == 1, var_2, var_r)
|
||||
var_g = torch.where(var_i == 1, V, var_g)
|
||||
var_b = torch.where(var_i == 1, var_1, var_b)
|
||||
|
||||
# else if ( var_i == 2 ) { var_r = var_1 ; var_g = V ; var_b = var_3 }
|
||||
var_r = torch.where(var_i == 2, var_1, var_r)
|
||||
var_g = torch.where(var_i == 2, V, var_g)
|
||||
var_b = torch.where(var_i == 2, var_3, var_b)
|
||||
|
||||
# else if ( var_i == 3 ) { var_r = var_1 ; var_g = var_2 ; var_b = V }
|
||||
var_r = torch.where(var_i == 3, var_1, var_r)
|
||||
var_g = torch.where(var_i == 3, var_2, var_g)
|
||||
var_b = torch.where(var_i == 3, V, var_b)
|
||||
|
||||
# else if ( var_i == 4 ) { var_r = var_3 ; var_g = var_1 ; var_b = V }
|
||||
var_r = torch.where(var_i == 4, var_3, var_r)
|
||||
var_g = torch.where(var_i == 4, var_1, var_g)
|
||||
var_b = torch.where(var_i == 4, V, var_b)
|
||||
|
||||
|
||||
R = var_r #* 255
|
||||
G = var_g #* 255
|
||||
B = var_b #* 255
|
||||
|
||||
|
||||
return torch.stack([R, G, B], 1)
|
||||
|
||||
|
||||
##### LAB - RGB
|
||||
|
||||
def _convert(matrix, arr):
|
||||
"""Do the color space conversion.
|
||||
Parameters
|
||||
----------
|
||||
matrix : array_like
|
||||
The 3x3 matrix to use.
|
||||
arr : array_like
|
||||
The input array.
|
||||
Returns
|
||||
-------
|
||||
out : ndarray, dtype=float
|
||||
The converted array.
|
||||
"""
|
||||
|
||||
if arr.is_cuda:
|
||||
matrix = matrix.cuda()
|
||||
|
||||
bs, ch, h, w = arr.shape
|
||||
|
||||
arr = arr.permute((0,2,3,1))
|
||||
arr = arr.contiguous().view(-1,1,3)
|
||||
|
||||
matrix = matrix.transpose(0,1).unsqueeze(0)
|
||||
matrix = matrix.repeat(arr.shape[0],1,1)
|
||||
|
||||
res = torch.bmm(arr,matrix)
|
||||
|
||||
res = res.view(bs,h,w,ch)
|
||||
res = res.transpose(3,2).transpose(2,1)
|
||||
|
||||
|
||||
return res
|
||||
|
||||
def get_xyz_coords(illuminant, observer):
|
||||
"""Get the XYZ coordinates of the given illuminant and observer [1]_.
|
||||
Parameters
|
||||
----------
|
||||
illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
|
||||
The name of the illuminant (the function is NOT case sensitive).
|
||||
observer : {"2", "10"}, optional
|
||||
The aperture angle of the observer.
|
||||
Returns
|
||||
-------
|
||||
(x, y, z) : tuple
|
||||
A tuple with 3 elements containing the XYZ coordinates of the given
|
||||
illuminant.
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If either the illuminant or the observer angle are not supported or
|
||||
unknown.
|
||||
References
|
||||
----------
|
||||
.. [1] https://en.wikipedia.org/wiki/Standard_illuminant
|
||||
"""
|
||||
illuminant = illuminant.upper()
|
||||
try:
|
||||
return illuminants[illuminant][observer]
|
||||
except KeyError:
|
||||
raise ValueError("Unknown illuminant/observer combination\
|
||||
(\'{0}\', \'{1}\')".format(illuminant, observer))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def rgb2xyz(rgb):
|
||||
|
||||
mask = rgb > 0.04045
|
||||
rgbm = rgb.clone()
|
||||
tmp = torch.pow((rgb + 0.055) / 1.055, 2.4)
|
||||
rgb = torch.where(mask, tmp, rgb)
|
||||
|
||||
rgbm = rgb.clone()
|
||||
rgb[~mask] = rgbm[~mask]/12.92
|
||||
return _convert(xyz_from_rgb, rgb)
|
||||
|
||||
def xyz2lab(xyz, illuminant="D65", observer="2"):
|
||||
|
||||
# arr = _prepare_colorarray(xyz)
|
||||
xyz_ref_white = get_xyz_coords(illuminant, observer)
|
||||
#cuda
|
||||
if xyz.is_cuda:
|
||||
xyz_ref_white = xyz_ref_white.cuda()
|
||||
|
||||
# scale by CIE XYZ tristimulus values of the reference white point
|
||||
xyz = xyz / xyz_ref_white.view(1,3,1,1)
|
||||
# Nonlinear distortion and linear transformation
|
||||
mask = xyz > 0.008856
|
||||
xyzm = xyz.clone()
|
||||
xyz[mask] = torch.pow(xyzm[mask], 1/3)
|
||||
xyzm = xyz.clone()
|
||||
xyz[~mask] = 7.787 * xyzm[~mask] + 16. / 116.
|
||||
x, y, z = xyz[:, 0, :, :], xyz[:, 1, :, :], xyz[:, 2, :, :]
|
||||
# Vector scaling
|
||||
L = (116. * y) - 16.
|
||||
a = 500.0 * (x - y)
|
||||
b = 200.0 * (y - z)
|
||||
return torch.stack((L,a,b), 1)
|
||||
|
||||
def rgb2lab(rgb, illuminant="D65", observer="2"):
|
||||
"""RGB to lab color space conversion.
|
||||
Parameters
|
||||
----------
|
||||
rgb : array_like
|
||||
The image in RGB format, in a 3- or 4-D array of shape
|
||||
``(.., ..,[ ..,] 3)``.
|
||||
illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
|
||||
The name of the illuminant (the function is NOT case sensitive).
|
||||
observer : {"2", "10"}, optional
|
||||
The aperture angle of the observer.
|
||||
Returns
|
||||
-------
|
||||
out : ndarray
|
||||
The image in Lab format, in a 3- or 4-D array of shape
|
||||
``(.., ..,[ ..,] 3)``.
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `rgb` is not a 3- or 4-D array of shape ``(.., ..,[ ..,] 3)``.
|
||||
References
|
||||
----------
|
||||
.. [1] https://en.wikipedia.org/wiki/Standard_illuminant
|
||||
Notes
|
||||
-----
|
||||
This function uses rgb2xyz and xyz2lab.
|
||||
By default Observer= 2A, Illuminant= D65. CIE XYZ tristimulus values
|
||||
x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for
|
||||
a list of supported illuminants.
|
||||
"""
|
||||
return xyz2lab(rgb2xyz(rgb), illuminant, observer)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def lab2xyz(lab, illuminant="D65", observer="2"):
|
||||
arr = lab.clone()
|
||||
L, a, b = arr[:, 0, :, :], arr[:, 1, :, :], arr[:, 2, :, :]
|
||||
y = (L + 16.) / 116.
|
||||
x = (a / 500.) + y
|
||||
z = y - (b / 200.)
|
||||
|
||||
# if (z < 0).sum() > 0:
|
||||
# warn('Color data out of range: Z < 0 in %s pixels' % (z < 0).sum().item())
|
||||
# z[z < 0] = 0 # NO GRADIENT!!!!
|
||||
|
||||
out = torch.stack((x, y, z),1)
|
||||
|
||||
mask = out > 0.2068966
|
||||
outm = out.clone()
|
||||
out[mask] = torch.pow(outm[mask], 3.)
|
||||
outm = out.clone()
|
||||
out[~mask] = (outm[~mask] - 16.0 / 116.) / 7.787
|
||||
|
||||
# rescale to the reference white (illuminant)
|
||||
xyz_ref_white = get_xyz_coords(illuminant, observer)
|
||||
# cuda
|
||||
if lab.is_cuda:
|
||||
xyz_ref_white = xyz_ref_white.cuda()
|
||||
xyz_ref_white = xyz_ref_white.unsqueeze(2).unsqueeze(2).repeat(1,1,out.shape[2],out.shape[3])
|
||||
out = out * xyz_ref_white
|
||||
return out
|
||||
|
||||
def xyz2rgb(xyz):
|
||||
arr = _convert(rgb_from_xyz, xyz)
|
||||
mask = arr > 0.0031308
|
||||
arrm = arr.clone()
|
||||
arr[mask] = 1.055 * torch.pow(arrm[mask], 1 / 2.4) - 0.055
|
||||
arrm = arr.clone()
|
||||
arr[~mask] = arrm[~mask] * 12.92
|
||||
|
||||
# CLAMP KILLS GRADIENTS
|
||||
# mask_z = arr < 0
|
||||
# arr[mask_z] = 0
|
||||
# mask_o = arr > 1
|
||||
# arr[mask_o] = 1
|
||||
|
||||
# torch.clamp(arr, 0, 1, out=arr)
|
||||
return arr
|
||||
|
||||
def lab2rgb(lab, illuminant="D65", observer="2"):
|
||||
"""Lab to RGB color space conversion.
|
||||
Parameters
|
||||
----------
|
||||
lab : array_like
|
||||
The image in Lab format, in a 3-D array of shape ``(.., .., 3)``.
|
||||
illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
|
||||
The name of the illuminant (the function is NOT case sensitive).
|
||||
observer : {"2", "10"}, optional
|
||||
The aperture angle of the observer.
|
||||
Returns
|
||||
-------
|
||||
out : ndarray
|
||||
The image in RGB format, in a 3-D array of shape ``(.., .., 3)``.
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `lab` is not a 3-D array of shape ``(.., .., 3)``.
|
||||
References
|
||||
----------
|
||||
.. [1] https://en.wikipedia.org/wiki/Standard_illuminant
|
||||
Notes
|
||||
-----
|
||||
This function uses lab2xyz and xyz2rgb.
|
||||
By default Observer= 2A, Illuminant= D65. CIE XYZ tristimulus values
|
||||
x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for
|
||||
a list of supported illuminants.
|
||||
"""
|
||||
return xyz2rgb(lab2xyz(lab, illuminant, observer))
|
136
codes/utils/fdpl_util.py
Normal file
|
@@ -0,0 +1,136 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# note: all dct related functions are either exactly as or based on those
|
||||
# at https://github.com/zh217/torch-dct
|
||||
def dct(x, norm=None):
|
||||
"""
|
||||
Discrete Cosine Transform, Type II (a.k.a. the DCT)
|
||||
For the meaning of the parameter `norm`, see:
|
||||
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
||||
:param x: the input signal
|
||||
:param norm: the normalization, None or 'ortho'
|
||||
:return: the DCT-II of the signal over the last dimension
|
||||
"""
|
||||
x_shape = x.shape
|
||||
N = x_shape[-1]
|
||||
x = x.contiguous().view(-1, N)
|
||||
|
||||
v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
|
||||
|
||||
Vc = torch.rfft(v, 1, onesided=False)
|
||||
|
||||
k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
|
||||
W_r = torch.cos(k)
|
||||
W_i = torch.sin(k)
|
||||
|
||||
V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
|
||||
|
||||
if norm == 'ortho':
|
||||
V[:, 0] /= np.sqrt(N) * 2
|
||||
V[:, 1:] /= np.sqrt(N / 2) * 2
|
||||
|
||||
V = 2 * V.view(*x_shape)
|
||||
|
||||
return V
|
||||
|
||||
def idct(X, norm=None):
|
||||
"""
|
||||
The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
|
||||
Our definition of idct is that idct(dct(x)) == x
|
||||
For the meaning of the parameter `norm`, see:
|
||||
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
||||
:param X: the input signal
|
||||
:param norm: the normalization, None or 'ortho'
|
||||
:return: the inverse DCT-II of the signal over the last dimension
|
||||
"""
|
||||
|
||||
x_shape = X.shape
|
||||
N = x_shape[-1]
|
||||
|
||||
X_v = X.contiguous().view(-1, x_shape[-1]) / 2
|
||||
|
||||
if norm == 'ortho':
|
||||
X_v[:, 0] *= np.sqrt(N) * 2
|
||||
X_v[:, 1:] *= np.sqrt(N / 2) * 2
|
||||
|
||||
k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
|
||||
W_r = torch.cos(k)
|
||||
W_i = torch.sin(k)
|
||||
|
||||
V_t_r = X_v
|
||||
V_t_r = V_t_r.to(X.device)
|
||||
V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
|
||||
V_t_i = V_t_i.to(X.device)
|
||||
|
||||
V_r = V_t_r * W_r - V_t_i * W_i
|
||||
V_i = V_t_r * W_i + V_t_i * W_r
|
||||
|
||||
V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
|
||||
|
||||
v = torch.irfft(V, 1, onesided=False)
|
||||
x = v.new_zeros(v.shape)
|
||||
x[:, ::2] += v[:, :N - (N // 2)]
|
||||
x[:, 1::2] += v.flip([1])[:, :N // 2]
|
||||
|
||||
return x.view(*x_shape)
|
||||
|
||||
def dct_2d(x, norm=None):
|
||||
"""
|
||||
2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
|
||||
For the meaning of the parameter `norm`, see:
|
||||
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
||||
:param x: the input signal
|
||||
:param norm: the normalization, None or 'ortho'
|
||||
:return: the DCT-II of the signal over the last 2 dimensions
|
||||
"""
|
||||
X1 = dct(x, norm=norm)
|
||||
X2 = dct(X1.transpose(-1, -2), norm=norm)
|
||||
return X2.transpose(-1, -2)
|
||||
|
||||
def idct_2d(X, norm=None):
|
||||
"""
|
||||
The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
|
||||
Our definition of idct is that idct_2d(dct_2d(x)) == x
|
||||
For the meaning of the parameter `norm`, see:
|
||||
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
||||
:param X: the input signal
|
||||
:param norm: the normalization, None or 'ortho'
|
||||
:return: the DCT-II of the signal over the last 2 dimensions
|
||||
"""
|
||||
x1 = idct(X, norm=norm)
|
||||
x2 = idct(x1.transpose(-1, -2), norm=norm)
|
||||
return x2.transpose(-1, -2)
|
||||
|
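A hedged sanity check for the transform pair above (note that torch.rfft/torch.irfft require an older PyTorch, pre-1.8):

import torch
from utils.fdpl_util import dct_2d, idct_2d

x = torch.rand(4, 3, 8, 8, dtype=torch.double)
err = (idct_2d(dct_2d(x, norm='ortho'), norm='ortho') - x).abs().max()
print(err)  # expected to be near machine precision in double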
||||
def extract_patches_2d(img,patch_shape,step=[1.0,1.0],batch_first=False):
|
||||
"""
|
||||
source: https://gist.github.com/dem123456789/23f18fd78ac8da9615c347905e64fc78
|
||||
"""
|
||||
patch_H, patch_W = patch_shape[0], patch_shape[1]
|
||||
if(img.size(2) < patch_H):
|
||||
num_padded_H_Top = (patch_H - img.size(2))//2
|
||||
num_padded_H_Bottom = patch_H - img.size(2) - num_padded_H_Top
|
||||
padding_H = nn.ConstantPad2d((0, 0, num_padded_H_Top, num_padded_H_Bottom), 0)
|
||||
img = padding_H(img)
|
||||
if(img.size(3) < patch_W):
|
||||
num_padded_W_Left = (patch_W - img.size(3))//2
|
||||
num_padded_W_Right = patch_W - img.size(3) - num_padded_W_Left
|
||||
padding_W = nn.ConstantPad2d((num_padded_W_Left,num_padded_W_Right, 0, 0), 0)
|
||||
img = padding_W(img)
|
||||
step_int = [0, 0]
|
||||
step_int[0] = int(patch_H*step[0]) if(isinstance(step[0], float)) else step[0]
|
||||
step_int[1] = int(patch_W*step[1]) if(isinstance(step[1], float)) else step[1]
|
||||
patches_fold_H = img.unfold(2, patch_H, step_int[0])
|
||||
if((img.size(2) - patch_H) % step_int[0] != 0):
|
||||
patches_fold_H = torch.cat((patches_fold_H,
|
||||
img[:, :, -patch_H:, :].permute(0,1,3,2).unsqueeze(2)),dim=2)
|
||||
patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_int[1])
|
||||
if((img.size(3) - patch_W) % step_int[1] != 0):
|
||||
patches_fold_HW = torch.cat((patches_fold_HW,
|
||||
patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)), dim=3)
|
||||
patches = patches_fold_HW.permute(2, 3, 0, 1, 4, 5)
|
||||
patches = patches.reshape(-1, img.size(0), img.size(1), patch_H, patch_W)
|
||||
if(batch_first):
|
||||
patches = patches.permute(1, 0, 2, 3, 4)
|
||||
return patches
|
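A hedged shape check matching how FDPLLoss calls this helper (batch_first=True, 8x8 patches):

import torch
from utils.fdpl_util import extract_patches_2d

img = torch.rand(2, 3, 32, 32)
patches = extract_patches_2d(img, patch_shape=(8, 8), batch_first=True)
# the default step=[1.0, 1.0] makes the stride equal to the patch size,
# so a 32x32 image yields a 4x4 grid of non-overlapping 8x8 patches
print(patches.shape)  # torch.Size([2, 16, 3, 8, 8])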