diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index 5cc44b07..8d0ec157 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -172,7 +172,7 @@ class LQGTDataset(data.Dataset): img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) - if self.opt['doResizeLoss']: + if 'doResizeLoss' in self.opt.keys() and self.opt['doResizeLoss']: r = random.randrange(0, 10) if r > 5: img_LQ = cv2.resize(img_LQ, (int(LQ_size/2), int(LQ_size/2)), interpolation=cv2.INTER_LINEAR) @@ -215,7 +215,7 @@ class LQGTDataset(data.Dataset): corruption_buffer.seek(0) img_LQ = Image.open(corruption_buffer) - if self.opt['grayscale']: + if 'grayscale' in self.opt.keys() and self.opt['grayscale']: img_LQ = ImageOps.grayscale(img_LQ).convert('RGB') img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() diff --git a/codes/data_scripts/compute_fdpl_perceptual_weights.py b/codes/data_scripts/compute_fdpl_perceptual_weights.py new file mode 100644 index 00000000..ebffbf13 --- /dev/null +++ b/codes/data_scripts/compute_fdpl_perceptual_weights.py @@ -0,0 +1,74 @@ +import torch +import os +from PIL import Image +import numpy as np +import options.options as option +from data import create_dataloader, create_dataset +import math +from tqdm import tqdm +from torchvision import transforms +from utils.fdpl_util import dct_2d, extract_patches_2d +import random +import matplotlib.pyplot as plt +from mpl_toolkits.axes_grid1 import make_axes_locatable +from utils.colors import rgb2ycbcr +import torch.nn.functional as F + +input_config = "../../options/train_imgset_pixgan_srg4_fdpl.yml" +output_file = "fdpr_diff_means.pt" +device = 'cuda' +patch_size=128 + +if __name__ == '__main__': + opt = option.parse(input_config, is_train=True) + opt['dist'] = False + + # Create a dataset to load from (this dataset loads HR/LR images and performs any distortions specified by the YML. 
+ dataset_opt = opt['datasets']['train'] + train_set = create_dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) + total_iters = int(opt['train']['niter']) + total_epochs = int(math.ceil(total_iters / train_size)) + train_loader = create_dataloader(train_set, dataset_opt, opt, None) + print('Number of train images: {:,d}, iters: {:,d}'.format( + len(train_set), train_size)) + + # calculate the perceptual weights + master_diff = np.zeros((patch_size, patch_size)) + num_patches = 0 + all_diff_patches = [] + tq = tqdm(train_loader) + sampled = 0 + for train_data in tq: + if sampled > 200: + break + sampled += 1 + + im = rgb2ycbcr(train_data['GT'].double()) + im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(), + size=im.shape[2:], + mode="bicubic")) + patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True) + patches_hr = dct_2d(patches_hr, norm='ortho') + patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size,patch_size), batch_first=True) + patches_lr = dct_2d(patches_lr, norm='ortho') + b, p, c, w, h = patches_hr.shape + diffs = torch.abs(patches_lr - patches_hr) / ((torch.abs(patches_lr) + torch.abs(patches_hr)) / 2 + .00000001) + num_patches += b * p + all_diff_patches.append(torch.sum(diffs, dim=(0, 1))) + + diff_patches = torch.stack(all_diff_patches, dim=0) + diff_means = torch.sum(diff_patches, dim=0) / num_patches + + torch.save(diff_means, output_file) + print(diff_means) + + for i in range(3): + fig, ax = plt.subplots() + divider = make_axes_locatable(ax) + cax = divider.append_axes('right', size='5%', pad=0.05) + im = ax.imshow(diff_means[i].numpy()) + ax.set_title("mean_diff for channel %i" % (i,)) + fig.colorbar(im, cax=cax, orientation='vertical') + plt.show() + diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 6a75acd0..91f20bb1 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -6,7 +6,7 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel import models.networks as networks import models.lr_scheduler as lr_scheduler from models.base_model import BaseModel -from models.loss import GANLoss +from models.loss import GANLoss, FDPLLoss from apex import amp from data.weight_scheduler import get_scheduler_for_opt import torch.nn.functional as F @@ -62,6 +62,16 @@ class SRGANModel(BaseModel): logger.info('Remove pixel loss.') self.cri_pix = None + # FDPL loss. + if 'fdpl_loss' in train_opt.keys(): + fdpl_opt = train_opt['fdpl_loss'] + self.fdpl_weight = fdpl_opt['weight'] + self.fdpl_enabled = self.fdpl_weight > 0 + if self.fdpl_enabled: + self.cri_fdpl = FDPLLoss(fdpl_opt['data_mean'], self.device) + else: + self.fdpl_enabled = False + # G feature loss if train_opt['feature_weight'] and train_opt['feature_weight'] > 0: # For backwards compatibility, use a scheduler definition instead. Remove this at some point. 
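Note on wiring: the FDPL branch above only activates when the parsed training options contain an fdpl_loss section, whose data_mean entry points at the file written by data_scripts/compute_fdpl_perceptual_weights.py. The key names in the sketch below come from the code above; the concrete values, the variable name, and the layout are illustrative assumptions, not part of this change:

# Hypothetical shape of the parsed training options read by the FDPL branch above.
# Key names match the code; the values here are placeholders.
train_opt_fdpl_example = {
    'fdpl_loss': {
        'weight': 1e-2,                     # becomes self.fdpl_weight; 0 disables the loss
        'data_mean': 'fdpr_diff_means.pt',  # file written by compute_fdpl_perceptual_weights.py
    },
    # ... the rest of the usual training options are unchanged ...
}

data_mean is handed directly to FDPLLoss as its dataset_diff_means_file argument; the loss term itself is added to the generator objective in the hunk below.
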
@@ -305,10 +315,14 @@ class SRGANModel(BaseModel): if using_gan_img: l_g_pix_log = None l_g_fea_log = None + l_g_fdpl = None if self.cri_pix and not using_gan_img: # pixel loss l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix) l_g_pix_log = l_g_pix / self.l_pix_w l_g_total += l_g_pix + if self.fdpl_enabled and not using_gan_img: + l_g_fdpl = self.cri_fdpl(fea_GenOut, pix) + l_g_total += l_g_fdpl * self.fdpl_weight if self.cri_fea and not using_gan_img: # feature loss real_fea = self.netF(pix).detach() fake_fea = self.netF(fea_GenOut) @@ -535,6 +549,8 @@ class SRGANModel(BaseModel): if step % self.D_update_ratio == 0 and step >= self.D_init_iters: if self.cri_pix and l_g_pix_log is not None: self.add_log_entry('l_g_pix', l_g_pix_log.item()) + if self.fdpl_enabled and l_g_fdpl is not None: + self.add_log_entry('l_g_fdpl', l_g_fdpl.item()) if self.cri_fea and l_g_fea_log is not None: self.add_log_entry('feature_weight', fea_w) self.add_log_entry('l_g_fea', l_g_fea_log.item()) diff --git a/codes/models/loss.py b/codes/models/loss.py index 9334f806..6698ad2b 100644 --- a/codes/models/loss.py +++ b/codes/models/loss.py @@ -1,5 +1,8 @@ import torch import torch.nn as nn +import numpy as np +from utils.fdpl_util import extract_patches_2d, dct_2d +from utils.colors import rgb2ycbcr class CharbonnierLoss(nn.Module): @@ -75,3 +78,99 @@ class GradientPenaltyLoss(nn.Module): loss = ((grad_interp_norm - 1)**2).mean() return loss + + +# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL +# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py +# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs). +class FDPLLoss(nn.Module): + """ + Loss function taking the MSE between the 2D DCT coefficients + of predicted and target images or an image channel. + DCT coefficients are computed for each 8x8 block of the image + + Important note about this loss: Since it operates in the frequency domain, precision is highly important. + It works on FP64 numbers and will fail if you attempt to use it with AMP. Recommend you split this loss + off from the rest of amp.scale_loss(). + """ + + def __init__(self, dataset_diff_means_file, device): + """ + dataset_diff_means (torch.tensor): Pre-computed frequency-domain mean differences between LR and HR images. + """ + # These values are derived from the JPEG standard. + qt_Y = torch.tensor([[16, 11, 10, 16, 24, 40, 51, 61], + [12, 12, 14, 19, 26, 58, 60, 55], + [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], + [18, 22, 37, 56, 68, 109, 103, 77], + [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], + [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=torch.double, + device=device, + requires_grad=False) + qt_C = torch.tensor([[17, 18, 24, 47, 99, 99, 99, 99], + [18, 21, 26, 66, 99, 99, 99, 99], + [24, 26, 56, 99, 99, 99, 99, 99], + [47, 66, 99, 99, 99, 99, 99, 99], + [99, 99, 99, 99, 99, 99, 99, 99], + [99, 99, 99, 99, 99, 99, 99, 99], + [99, 99, 99, 99, 99, 99, 99, 99], + [99, 99, 99, 99, 99, 99, 99, 99]], + dtype=torch.double, + device=device, + requires_grad=False) + """ + Reasoning behind this perceptual weight matrix: JPEG gives as a model of frequencies that are important + for human perception. In that model, lower frequencies are more important than higher frequencies. Because + of this, the higher frequencies are the first to go during compression. 
As compression increases, the effect
+        spreads to the lower frequencies, which degrades perceptual quality. But when the lower frequencies are
+        preserved, JPEG does an excellent job of compression without a noticeable loss of quality.
+        In super resolution, we already have the low frequencies. In fact, that is really all we have in the low
+        resolution images.
+        As evidenced by the diff_means matrix above, what is lost in the SR images is the mid-range frequencies;
+        those across and towards the centre of the diagonal. We can bias our model to recover these frequencies
+        by having our loss function prioritize these coefficients, with priority determined by the magnitude of
+        relative change between the low-res and high-res images. But we can take this further and turn it into a
+        true perceptual loss by additionally prioritizing DCT coefficients by the importance that has been assigned
+        to them by the JPEG quantization table. That is how the table below is created.
+
+        The problem is that we don't know whether the JPEG model is optimal. So there is room for qualitative
+        evaluation of the quantization table values. We can refine our perceptual weights by deleting each in turn
+        for a small set of images and evaluating the resulting change in perceived quality. I can do this on my own
+        to start and, if it works, follow up with a small user study to confirm it.
+        """
+        diff_means = torch.tensor(torch.load(dataset_diff_means_file), device=device)
+        perceptual_weights = torch.stack([(torch.ones_like(qt_Y, device=device) / qt_Y),
+                                          (torch.ones_like(qt_C, device=device) / qt_C),
+                                          (torch.ones_like(qt_C, device=device) / qt_C)])
+        perceptual_weights = perceptual_weights * diff_means
+        self.perceptual_weights = perceptual_weights / torch.mean(perceptual_weights)
+        super(FDPLLoss, self).__init__()
+
+    def forward(self, predictions, targets):
+        """
+        Args:
+            predictions (torch.tensor): output of an image transformation model.
+                                        shape: batch_size x 3 x H x W
+            targets (torch.tensor): ground truth images corresponding to the predictions.
+                                    shape: batch_size x 3 x H x W
+
+        Returns:
+            loss (torch.tensor): weighted squared error between the predicted and ground truth 2D DCT coefficients
+        """
+        # transition to fp64 and then convert to the YCbCr color space.
+ predictions = rgb2ycbcr(predictions.double()) + targets = rgb2ycbcr(targets.double()) + + # get DCT coefficients of ground truth patches + patches = extract_patches_2d(img=targets, patch_shape=(8, 8), batch_first=True) + ground_truth_dct = dct_2d(patches, norm='ortho') + + # get DCT coefficients of transformed images + patches = extract_patches_2d(img=predictions, patch_shape=(8, 8), batch_first=True) + outputs_dct = dct_2d(patches, norm='ortho') + loss = torch.sum(((outputs_dct - ground_truth_dct).pow(2)) * self.perceptual_weights) + return loss \ No newline at end of file diff --git a/codes/utils/colors.py b/codes/utils/colors.py new file mode 100644 index 00000000..edfb7dd8 --- /dev/null +++ b/codes/utils/colors.py @@ -0,0 +1,409 @@ +# Differentiable color conversions from https://github.com/TheZino/pytorch-color-conversions + +from warnings import warn + +import numpy as np +import torch +from scipy import linalg + +xyz_from_rgb = torch.Tensor([[0.412453, 0.357580, 0.180423], + [0.212671, 0.715160, 0.072169], + [0.019334, 0.119193, 0.950227]]) + +rgb_from_xyz = torch.Tensor(linalg.inv(xyz_from_rgb)) + +illuminants = \ + {"A": {'2': torch.Tensor([(1.098466069456375, 1, 0.3558228003436005)]), + '10': torch.Tensor([(1.111420406956693, 1, 0.3519978321919493)])}, + "D50": {'2': torch.Tensor([(0.9642119944211994, 1, 0.8251882845188288)]), + '10': torch.Tensor([(0.9672062750333777, 1, 0.8142801513128616)])}, + "D55": {'2': torch.Tensor([(0.956797052643698, 1, 0.9214805860173273)]), + '10': torch.Tensor([(0.9579665682254781, 1, 0.9092525159847462)])}, + "D65": {'2': torch.Tensor([(0.95047, 1., 1.08883)]), # This was: `lab_ref_white` + '10': torch.Tensor([(0.94809667673716, 1, 1.0730513595166162)])}, + "D75": {'2': torch.Tensor([(0.9497220898840717, 1, 1.226393520724154)]), + '10': torch.Tensor([(0.9441713925645873, 1, 1.2064272211720228)])}, + "E": {'2': torch.Tensor([(1.0, 1.0, 1.0)]), + '10': torch.Tensor([(1.0, 1.0, 1.0)])}} + + +# ------------------------------------------------------------- +# The conversion functions that make use of the matrices above +# ------------------------------------------------------------- + + +##### RGB - YCbCr + +# Helper for the creation of module-global constant tensors +def _t(data): + # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # TODO inherit this + device = torch.device("cpu") # TODO inherit this + return torch.tensor(data, requires_grad=False, dtype=torch.float32, device=device) + +# Helper for color matrix multiplication +def _mul(coeffs, image): + # This is implementation is clearly suboptimal. The function will + # be implemented with 'einsum' when a bug in pytorch 0.4.0 will be + # fixed (Einsum modifies variables in-place #7763). 
+ coeffs = coeffs.to(image.device) + r0 = image[:, 0:1, :, :].repeat(1, 3, 1, 1) * coeffs[:, 0].view(1, 3, 1, 1) + r1 = image[:, 1:2, :, :].repeat(1, 3, 1, 1) * coeffs[:, 1].view(1, 3, 1, 1) + r2 = image[:, 2:3, :, :].repeat(1, 3, 1, 1) * coeffs[:, 2].view(1, 3, 1, 1) + return r0 + r1 + r2 + # return torch.einsum("dc,bcij->bdij", (coeffs.to(image.device), image)) + +_RGB_TO_YCBCR = _t([[0.257, 0.504, 0.098], [-0.148, -0.291, 0.439], [0.439 , -0.368, -0.071]]) +_YCBCR_OFF = _t([0.063, 0.502, 0.502]).view(1, 3, 1, 1) + + +def rgb2ycbcr(rgb): + """sRGB to YCbCr conversion.""" + clip_rgb=False + if clip_rgb: + rgb = torch.clamp(rgb, 0, 1) + return _mul(_RGB_TO_YCBCR, rgb) + _YCBCR_OFF.to(rgb.device) + + +def ycbcr2rgb(rgb): + """YCbCr to sRGB conversion.""" + clip_rgb=False + rgb = _mul(torch.inverse(_RGB_TO_YCBCR), rgb - _YCBCR_OFF.to(rgb.device)) + if clip_rgb: + rgb = torch.clamp(rgb, 0, 1) + return rgb + + +##### HSV - RGB + +def rgb2hsv(rgb): + """ + R, G and B input range = 0 ÷ 1.0 + H, S and V output range = 0 ÷ 1.0 + """ + eps = 1e-7 + + var_R = rgb[:,0,:,:] + var_G = rgb[:,1,:,:] + var_B = rgb[:,2,:,:] + + var_Min = rgb.min(1)[0] #Min. value of RGB + var_Max = rgb.max(1)[0] #Max. value of RGB + del_Max = var_Max - var_Min ##Delta RGB value + + H = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device) + S = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device) + V = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device) + + V = var_Max + + #This is a gray, no chroma... + mask = del_Max == 0 + H[mask] = 0 + S[mask] = 0 + + #Chromatic data... + S = del_Max / (var_Max + eps) + + del_R = ( ( ( var_Max - var_R ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps) + del_G = ( ( ( var_Max - var_G ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps) + del_B = ( ( ( var_Max - var_B ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps) + + H = torch.where( var_R == var_Max , del_B - del_G, H) + H = torch.where( var_G == var_Max , ( 1 / 3 ) + del_R - del_B, H) + H = torch.where( var_B == var_Max ,( 2 / 3 ) + del_G - del_R, H) + + # if ( H < 0 ) H += 1 + # if ( H > 1 ) H -= 1 + + return torch.stack([H, S, V], 1) + +def hsv2rgb(hsv): + """ + H, S and V input range = 0 ÷ 1.0 + R, G and B output range = 0 ÷ 1.0 + """ + + eps = 1e-7 + + bb,cc,hh,ww = hsv.shape + H = hsv[:,0,:,:] + S = hsv[:,1,:,:] + V = hsv[:,2,:,:] + + # var_h = torch.zeros(bb,hh,ww) + # var_s = torch.zeros(bb,hh,ww) + # var_v = torch.zeros(bb,hh,ww) + + # var_r = torch.zeros(bb,hh,ww) + # var_g = torch.zeros(bb,hh,ww) + # var_b = torch.zeros(bb,hh,ww) + + # Grayscale + if (S == 0).all(): + + R = V + G = V + B = V + + # Chromatic data + else: + + var_h = H * 6 + + var_h[var_h == 6] = 0 #H must be < 1 + var_i = var_h.floor() #Or ... 
var_i = floor( var_h ) + var_1 = V * ( 1 - S ) + var_2 = V * ( 1 - S * ( var_h - var_i ) ) + var_3 = V * ( 1 - S * ( 1 - ( var_h - var_i ) ) ) + + # else { var_r = V ; var_g = var_1 ; var_b = var_2 } + var_r = V + var_g = var_1 + var_b = var_2 + + # var_i == 0 { var_r = V ; var_g = var_3 ; var_b = var_1 } + var_r = torch.where(var_i == 0, V, var_r) + var_g = torch.where(var_i == 0, var_3, var_g) + var_b = torch.where(var_i == 0, var_1, var_b) + + # else if ( var_i == 1 ) { var_r = var_2 ; var_g = V ; var_b = var_1 } + var_r = torch.where(var_i == 1, var_2, var_r) + var_g = torch.where(var_i == 1, V, var_g) + var_b = torch.where(var_i == 1, var_1, var_b) + + # else if ( var_i == 2 ) { var_r = var_1 ; var_g = V ; var_b = var_3 } + var_r = torch.where(var_i == 2, var_1, var_r) + var_g = torch.where(var_i == 2, V, var_g) + var_b = torch.where(var_i == 2, var_3, var_b) + + # else if ( var_i == 3 ) { var_r = var_1 ; var_g = var_2 ; var_b = V } + var_r = torch.where(var_i == 3, var_1, var_r) + var_g = torch.where(var_i == 3, var_2, var_g) + var_b = torch.where(var_i == 3, V, var_b) + + # else if ( var_i == 4 ) { var_r = var_3 ; var_g = var_1 ; var_b = V } + var_r = torch.where(var_i == 4, var_3, var_r) + var_g = torch.where(var_i == 4, var_1, var_g) + var_b = torch.where(var_i == 4, V, var_b) + + + R = var_r #* 255 + G = var_g #* 255 + B = var_b #* 255 + + + return torch.stack([R, G, B], 1) + + +##### LAB - RGB + +def _convert(matrix, arr): + """Do the color space conversion. + Parameters + ---------- + matrix : array_like + The 3x3 matrix to use. + arr : array_like + The input array. + Returns + ------- + out : ndarray, dtype=float + The converted array. + """ + + if arr.is_cuda: + matrix = matrix.cuda() + + bs, ch, h, w = arr.shape + + arr = arr.permute((0,2,3,1)) + arr = arr.contiguous().view(-1,1,3) + + matrix = matrix.transpose(0,1).unsqueeze(0) + matrix = matrix.repeat(arr.shape[0],1,1) + + res = torch.bmm(arr,matrix) + + res = res.view(bs,h,w,ch) + res = res.transpose(3,2).transpose(2,1) + + + return res + +def get_xyz_coords(illuminant, observer): + """Get the XYZ coordinates of the given illuminant and observer [1]_. + Parameters + ---------- + illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional + The name of the illuminant (the function is NOT case sensitive). + observer : {"2", "10"}, optional + The aperture angle of the observer. + Returns + ------- + (x, y, z) : tuple + A tuple with 3 elements containing the XYZ coordinates of the given + illuminant. + Raises + ------ + ValueError + If either the illuminant or the observer angle are not supported or + unknown. + References + ---------- + .. 
[1] https://en.wikipedia.org/wiki/Standard_illuminant + """ + illuminant = illuminant.upper() + try: + return illuminants[illuminant][observer] + except KeyError: + raise ValueError("Unknown illuminant/observer combination\ + (\'{0}\', \'{1}\')".format(illuminant, observer)) + + + + + +def rgb2xyz(rgb): + + mask = rgb > 0.04045 + rgbm = rgb.clone() + tmp = torch.pow((rgb + 0.055) / 1.055, 2.4) + rgb = torch.where(mask, tmp, rgb) + + rgbm = rgb.clone() + rgb[~mask] = rgbm[~mask]/12.92 + return _convert(xyz_from_rgb, rgb) + +def xyz2lab(xyz, illuminant="D65", observer="2"): + + # arr = _prepare_colorarray(xyz) + xyz_ref_white = get_xyz_coords(illuminant, observer) + #cuda + if xyz.is_cuda: + xyz_ref_white = xyz_ref_white.cuda() + + # scale by CIE XYZ tristimulus values of the reference white point + xyz = xyz / xyz_ref_white.view(1,3,1,1) + # Nonlinear distortion and linear transformation + mask = xyz > 0.008856 + xyzm = xyz.clone() + xyz[mask] = torch.pow(xyzm[mask], 1/3) + xyzm = xyz.clone() + xyz[~mask] = 7.787 * xyzm[~mask] + 16. / 116. + x, y, z = xyz[:, 0, :, :], xyz[:, 1, :, :], xyz[:, 2, :, :] + # Vector scaling + L = (116. * y) - 16. + a = 500.0 * (x - y) + b = 200.0 * (y - z) + return torch.stack((L,a,b), 1) + +def rgb2lab(rgb, illuminant="D65", observer="2"): + """RGB to lab color space conversion. + Parameters + ---------- + rgb : array_like + The image in RGB format, in a 3- or 4-D array of shape + ``(.., ..,[ ..,] 3)``. + illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional + The name of the illuminant (the function is NOT case sensitive). + observer : {"2", "10"}, optional + The aperture angle of the observer. + Returns + ------- + out : ndarray + The image in Lab format, in a 3- or 4-D array of shape + ``(.., ..,[ ..,] 3)``. + Raises + ------ + ValueError + If `rgb` is not a 3- or 4-D array of shape ``(.., ..,[ ..,] 3)``. + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Standard_illuminant + Notes + ----- + This function uses rgb2xyz and xyz2lab. + By default Observer= 2A, Illuminant= D65. CIE XYZ tristimulus values + x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for + a list of supported illuminants. + """ + return xyz2lab(rgb2xyz(rgb), illuminant, observer) + + + + + +def lab2xyz(lab, illuminant="D65", observer="2"): + arr = lab.clone() + L, a, b = arr[:, 0, :, :], arr[:, 1, :, :], arr[:, 2, :, :] + y = (L + 16.) / 116. + x = (a / 500.) + y + z = y - (b / 200.) + + # if (z < 0).sum() > 0: + # warn('Color data out of range: Z < 0 in %s pixels' % (z < 0).sum().item()) + # z[z < 0] = 0 # NO GRADIENT!!!! + + out = torch.stack((x, y, z),1) + + mask = out > 0.2068966 + outm = out.clone() + out[mask] = torch.pow(outm[mask], 3.) + outm = out.clone() + out[~mask] = (outm[~mask] - 16.0 / 116.) 
/ 7.787 + + # rescale to the reference white (illuminant) + xyz_ref_white = get_xyz_coords(illuminant, observer) + # cuda + if lab.is_cuda: + xyz_ref_white = xyz_ref_white.cuda() + xyz_ref_white = xyz_ref_white.unsqueeze(2).unsqueeze(2).repeat(1,1,out.shape[2],out.shape[3]) + out = out * xyz_ref_white + return out + +def xyz2rgb(xyz): + arr = _convert(rgb_from_xyz, xyz) + mask = arr > 0.0031308 + arrm = arr.clone() + arr[mask] = 1.055 * torch.pow(arrm[mask], 1 / 2.4) - 0.055 + arrm = arr.clone() + arr[~mask] = arrm[~mask] * 12.92 + + # CLAMP KILLS GRADIENTS + # mask_z = arr < 0 + # arr[mask_z] = 0 + # mask_o = arr > 1 + # arr[mask_o] = 1 + + # torch.clamp(arr, 0, 1, out=arr) + return arr + +def lab2rgb(lab, illuminant="D65", observer="2"): + """Lab to RGB color space conversion. + Parameters + ---------- + lab : array_like + The image in Lab format, in a 3-D array of shape ``(.., .., 3)``. + illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional + The name of the illuminant (the function is NOT case sensitive). + observer : {"2", "10"}, optional + The aperture angle of the observer. + Returns + ------- + out : ndarray + The image in RGB format, in a 3-D array of shape ``(.., .., 3)``. + Raises + ------ + ValueError + If `lab` is not a 3-D array of shape ``(.., .., 3)``. + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Standard_illuminant + Notes + ----- + This function uses lab2xyz and xyz2rgb. + By default Observer= 2A, Illuminant= D65. CIE XYZ tristimulus values + x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for + a list of supported illuminants. + """ + return xyz2rgb(lab2xyz(lab, illuminant, observer)) diff --git a/codes/utils/fdpl_util.py b/codes/utils/fdpl_util.py new file mode 100644 index 00000000..2ec8cd11 --- /dev/null +++ b/codes/utils/fdpl_util.py @@ -0,0 +1,136 @@ +import numpy as np +import torch +import torch.nn as nn + +# note: all dct related functions are either exactly as or based on those +# at https://github.com/zh217/torch-dct +def dct(x, norm=None): + """ + Discrete Cosine Transform, Type II (a.k.a. 
the DCT)
+    For the meaning of the parameter `norm`, see:
+    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+    :param x: the input signal
+    :param norm: the normalization, None or 'ortho'
+    :return: the DCT-II of the signal over the last dimension
+    """
+    x_shape = x.shape
+    N = x_shape[-1]
+    x = x.contiguous().view(-1, N)
+
+    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
+
+    Vc = torch.rfft(v, 1, onesided=False)
+
+    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
+    W_r = torch.cos(k)
+    W_i = torch.sin(k)
+
+    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
+
+    if norm == 'ortho':
+        V[:, 0] /= np.sqrt(N) * 2
+        V[:, 1:] /= np.sqrt(N / 2) * 2
+
+    V = 2 * V.view(*x_shape)
+
+    return V
+
+def idct(X, norm=None):
+    """
+    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
+    Our definition of idct is that idct(dct(x)) == x
+    For the meaning of the parameter `norm`, see:
+    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+    :param X: the input signal
+    :param norm: the normalization, None or 'ortho'
+    :return: the inverse DCT-II of the signal over the last dimension
+    """
+
+    x_shape = X.shape
+    N = x_shape[-1]
+
+    X_v = X.contiguous().view(-1, x_shape[-1]) / 2
+
+    if norm == 'ortho':
+        X_v[:, 0] *= np.sqrt(N) * 2
+        X_v[:, 1:] *= np.sqrt(N / 2) * 2
+
+    k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
+    W_r = torch.cos(k)
+    W_i = torch.sin(k)
+
+    V_t_r = X_v
+    V_t_r = V_t_r.to(X.device)
+    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
+    V_t_i = V_t_i.to(X.device)
+
+    V_r = V_t_r * W_r - V_t_i * W_i
+    V_i = V_t_r * W_i + V_t_i * W_r
+
+    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
+
+    v = torch.irfft(V, 1, onesided=False)
+    x = v.new_zeros(v.shape)
+    x[:, ::2] += v[:, :N - (N // 2)]
+    x[:, 1::2] += v.flip([1])[:, :N // 2]
+
+    return x.view(*x_shape)
+
+def dct_2d(x, norm=None):
+    """
+    2-dimensional Discrete Cosine Transform, Type II (a.k.a.
the DCT) + For the meaning of the parameter `norm`, see: + https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html + :param x: the input signal + :param norm: the normalization, None or 'ortho' + :return: the DCT-II of the signal over the last 2 dimensions + """ + X1 = dct(x, norm=norm) + X2 = dct(X1.transpose(-1, -2), norm=norm) + return X2.transpose(-1, -2) + +def idct_2d(X, norm=None): + """ + The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III + Our definition of idct is that idct_2d(dct_2d(x)) == x + For the meaning of the parameter `norm`, see: + https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html + :param X: the input signal + :param norm: the normalization, None or 'ortho' + :return: the DCT-II of the signal over the last 2 dimensions + """ + x1 = idct(X, norm=norm) + x2 = idct(x1.transpose(-1, -2), norm=norm) + return x2.transpose(-1, -2) + +def extract_patches_2d(img,patch_shape,step=[1.0,1.0],batch_first=False): + """ + source: https://gist.github.com/dem123456789/23f18fd78ac8da9615c347905e64fc78 + """ + patch_H, patch_W = patch_shape[0], patch_shape[1] + if(img.size(2) < patch_H): + num_padded_H_Top = (patch_H - img.size(2))//2 + num_padded_H_Bottom = patch_H - img.size(2) - num_padded_H_Top + padding_H = nn.ConstantPad2d((0, 0, num_padded_H_Top, num_padded_H_Bottom), 0) + img = padding_H(img) + if(img.size(3) < patch_W): + num_padded_W_Left = (patch_W - img.size(3))//2 + num_padded_W_Right = patch_W - img.size(3) - num_padded_W_Left + padding_W = nn.ConstantPad2d((num_padded_W_Left,num_padded_W_Right, 0, 0), 0) + img = padding_W(img) + step_int = [0, 0] + step_int[0] = int(patch_H*step[0]) if(isinstance(step[0], float)) else step[0] + step_int[1] = int(patch_W*step[1]) if(isinstance(step[1], float)) else step[1] + patches_fold_H = img.unfold(2, patch_H, step_int[0]) + if((img.size(2) - patch_H) % step_int[0] != 0): + patches_fold_H = torch.cat((patches_fold_H, + img[:, :, -patch_H:, :].permute(0,1,3,2).unsqueeze(2)),dim=2) + patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_int[1]) + if((img.size(3) - patch_W) % step_int[1] != 0): + patches_fold_HW = torch.cat((patches_fold_HW, + patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)), dim=3) + patches = patches_fold_HW.permute(2, 3, 0, 1, 4, 5) + patches = patches.reshape(-1, img.size(0), img.size(1), patch_H, patch_W) + if(batch_first): + patches = patches.permute(1, 0, 2, 3, 4) + return patches
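
For completeness, here is a minimal usage sketch of the new loss outside of SRGANModel (not part of this change). It assumes the snippet is run from the codes/ directory, and it stands in a small random tensor for the precomputed diff means: FDPLLoss multiplies that tensor against its three stacked 8x8 quantization tables, so a 3x8x8 double tensor and the filename example_diff_means.pt are assumptions here; in practice you would point it at the file produced by compute_fdpl_perceptual_weights.py.

import torch
from models.loss import FDPLLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in for the precomputed frequency-domain mean differences (3 channels x 8x8 DCT block).
torch.save(torch.rand(3, 8, 8, dtype=torch.double), 'example_diff_means.pt')
fdpl = FDPLLoss(dataset_diff_means_file='example_diff_means.pt', device=device)

# Fake generator output / ground-truth pair: batches of RGB crops in [0, 1].
sr = torch.rand(4, 3, 64, 64, device=device, requires_grad=True)
hr = torch.rand(4, 3, 64, 64, device=device)

# The loss casts to fp64 and converts to YCbCr internally, then compares weighted
# 8x8 DCT coefficients; keep this term outside amp.scale_loss(), as the FDPLLoss
# docstring warns.
loss = fdpl(sr, hr)
loss.backward()
print(loss.item(), sr.grad.shape)

As in the diff above, the returned value is a weighted sum over all 8x8 DCT coefficients, so its magnitude grows with batch size and crop size; the fdpl_loss weight in the training options should be chosen with that in mind.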