forked from mrq/DL-Art-School
Add FDPL Loss
New loss type that can replace PSNR loss. It works in the frequency domain and focuses on the frequency features that are lost during the HR->LR conversion.
parent 85ee64b8d9, commit 7629cb0e61
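
At its core, the loss converts both images to YCbCr, takes the 2D DCT of 8x8 blocks, and sums a per-coefficient weighted squared error. A rough sketch of that computation using the helpers added by this commit (simplified and illustrative; the real implementation is the FDPLLoss class in models/loss.py below):

import torch
from utils.fdpl_util import dct_2d, extract_patches_2d
from utils.colors import rgb2ycbcr

def fdpl_sketch(pred, target, perceptual_weights):
    # DCT of every 8x8 block of the YCbCr images, compared coefficient by coefficient.
    pred_dct = dct_2d(extract_patches_2d(rgb2ycbcr(pred.double()), (8, 8), batch_first=True), norm='ortho')
    tgt_dct = dct_2d(extract_patches_2d(rgb2ycbcr(target.double()), (8, 8), batch_first=True), norm='ortho')
    return torch.sum((pred_dct - tgt_dct).pow(2) * perceptual_weights)
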
@@ -172,7 +172,7 @@ class LQGTDataset(data.Dataset):
         img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
         img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

-        if self.opt['doResizeLoss']:
+        if 'doResizeLoss' in self.opt.keys() and self.opt['doResizeLoss']:
             r = random.randrange(0, 10)
             if r > 5:
                 img_LQ = cv2.resize(img_LQ, (int(LQ_size/2), int(LQ_size/2)), interpolation=cv2.INTER_LINEAR)
@@ -215,7 +215,7 @@ class LQGTDataset(data.Dataset):
             corruption_buffer.seek(0)
             img_LQ = Image.open(corruption_buffer)

-        if self.opt['grayscale']:
+        if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
             img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')

         img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()

codes/data_scripts/compute_fdpl_perceptual_weights.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import torch
import os
from PIL import Image
import numpy as np
import options.options as option
from data import create_dataloader, create_dataset
import math
from tqdm import tqdm
from torchvision import transforms
from utils.fdpl_util import dct_2d, extract_patches_2d
import random
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from utils.colors import rgb2ycbcr
import torch.nn.functional as F

input_config = "../../options/train_imgset_pixgan_srg4_fdpl.yml"
output_file = "fdpr_diff_means.pt"
device = 'cuda'
patch_size = 128

if __name__ == '__main__':
    opt = option.parse(input_config, is_train=True)
    opt['dist'] = False

    # Create a dataset to load from (this dataset loads HR/LR images and performs any distortions specified by the YML).
    dataset_opt = opt['datasets']['train']
    train_set = create_dataset(dataset_opt)
    train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
    total_iters = int(opt['train']['niter'])
    total_epochs = int(math.ceil(total_iters / train_size))
    train_loader = create_dataloader(train_set, dataset_opt, opt, None)
    print('Number of train images: {:,d}, iters: {:,d}'.format(
        len(train_set), train_size))

    # calculate the perceptual weights
    master_diff = np.zeros((patch_size, patch_size))
    num_patches = 0
    all_diff_patches = []
    tq = tqdm(train_loader)
    sampled = 0
    for train_data in tq:
        if sampled > 200:
            break
        sampled += 1

        im = rgb2ycbcr(train_data['GT'].double())
        im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
                                        size=im.shape[2:],
                                        mode="bicubic"))
        patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size, patch_size), batch_first=True)
        patches_hr = dct_2d(patches_hr, norm='ortho')
        patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size, patch_size), batch_first=True)
        patches_lr = dct_2d(patches_lr, norm='ortho')
        b, p, c, w, h = patches_hr.shape
        diffs = torch.abs(patches_lr - patches_hr) / ((torch.abs(patches_lr) + torch.abs(patches_hr)) / 2 + .00000001)
        num_patches += b * p
        all_diff_patches.append(torch.sum(diffs, dim=(0, 1)))

    diff_patches = torch.stack(all_diff_patches, dim=0)
    diff_means = torch.sum(diff_patches, dim=0) / num_patches

    torch.save(diff_means, output_file)
    print(diff_means)

    for i in range(3):
        fig, ax = plt.subplots()
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        im = ax.imshow(diff_means[i].numpy())
        ax.set_title("mean_diff for channel %i" % (i,))
        fig.colorbar(im, cax=cax, orientation='vertical')
        plt.show()
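
The diff_means tensor saved by this script is what the SRGAN model below passes to FDPLLoss as its data_mean option. A minimal wiring sketch (file name and device are the defaults from this script; in training, the wiring goes through the options shown in the next hunks):

from models.loss import FDPLLoss
cri_fdpl = FDPLLoss(dataset_diff_means_file="fdpr_diff_means.pt", device='cuda')
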
@@ -6,7 +6,7 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from models.base_model import BaseModel
-from models.loss import GANLoss
+from models.loss import GANLoss, FDPLLoss
from apex import amp
from data.weight_scheduler import get_scheduler_for_opt
import torch.nn.functional as F

@@ -62,6 +62,16 @@ class SRGANModel(BaseModel):
                logger.info('Remove pixel loss.')
                self.cri_pix = None

+            # FDPL loss.
+            if 'fdpl_loss' in train_opt.keys():
+                fdpl_opt = train_opt['fdpl_loss']
+                self.fdpl_weight = fdpl_opt['weight']
+                self.fdpl_enabled = self.fdpl_weight > 0
+                if self.fdpl_enabled:
+                    self.cri_fdpl = FDPLLoss(fdpl_opt['data_mean'], self.device)
+            else:
+                self.fdpl_enabled = False
+
            # G feature loss
            if train_opt['feature_weight'] and train_opt['feature_weight'] > 0:
                # For backwards compatibility, use a scheduler definition instead. Remove this at some point.
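
For reference, a sketch of the training options this block reads; the weight value here is hypothetical and would normally be set in the experiment YML (e.g. options/train_imgset_pixgan_srg4_fdpl.yml, referenced by the weight-computation script above):

# Hypothetical options consumed by the FDPL block above.
train_opt = {
    'fdpl_loss': {
        'weight': 1e-4,                      # assumed value; any weight > 0 enables the loss
        'data_mean': 'fdpr_diff_means.pt',   # output of compute_fdpl_perceptual_weights.py
    },
}
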
@@ -305,10 +315,14 @@ class SRGANModel(BaseModel):
                if using_gan_img:
                    l_g_pix_log = None
                    l_g_fea_log = None
+                   l_g_fdpl = None
                if self.cri_pix and not using_gan_img:  # pixel loss
                    l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
                    l_g_pix_log = l_g_pix / self.l_pix_w
                    l_g_total += l_g_pix
+               if self.fdpl_enabled and not using_gan_img:
+                   l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
+                   l_g_total += l_g_fdpl * self.fdpl_weight
                if self.cri_fea and not using_gan_img:  # feature loss
                    real_fea = self.netF(pix).detach()
                    fake_fea = self.netF(fea_GenOut)
@@ -535,6 +549,8 @@ class SRGANModel(BaseModel):
        if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
            if self.cri_pix and l_g_pix_log is not None:
                self.add_log_entry('l_g_pix', l_g_pix_log.item())
+           if self.fdpl_enabled and l_g_fdpl is not None:
+               self.add_log_entry('l_g_fdpl', l_g_fdpl.item())
            if self.cri_fea and l_g_fea_log is not None:
                self.add_log_entry('feature_weight', fea_w)
                self.add_log_entry('l_g_fea', l_g_fea_log.item())

@@ -1,5 +1,8 @@
import torch
import torch.nn as nn
+import numpy as np
+from utils.fdpl_util import extract_patches_2d, dct_2d
+from utils.colors import rgb2ycbcr


class CharbonnierLoss(nn.Module):

@@ -75,3 +78,99 @@ class GradientPenaltyLoss(nn.Module):
        loss = ((grad_interp_norm - 1)**2).mean()
        return loss
+
+
+# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL
+# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py
+# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs).
+class FDPLLoss(nn.Module):
+    """
+    Loss function taking the MSE between the 2D DCT coefficients
+    of predicted and target images or an image channel.
+    DCT coefficients are computed for each 8x8 block of the image.
+
+    Important note about this loss: since it operates in the frequency domain, precision is highly important.
+    It works on FP64 numbers and will fail if you attempt to use it with AMP. Recommend you split this loss
+    off from the rest of amp.scale_loss().
+    """
+
+    def __init__(self, dataset_diff_means_file, device):
+        """
+        dataset_diff_means (torch.tensor): Pre-computed frequency-domain mean differences between LR and HR images.
+        """
+        # These values are derived from the JPEG standard.
+        qt_Y = torch.tensor([[16, 11, 10, 16, 24, 40, 51, 61],
+                             [12, 12, 14, 19, 26, 58, 60, 55],
+                             [14, 13, 16, 24, 40, 57, 69, 56],
+                             [14, 17, 22, 29, 51, 87, 80, 62],
+                             [18, 22, 37, 56, 68, 109, 103, 77],
+                             [24, 35, 55, 64, 81, 104, 113, 92],
+                             [49, 64, 78, 87, 103, 121, 120, 101],
+                             [72, 92, 95, 98, 112, 100, 103, 99]],
+                            dtype=torch.double,
+                            device=device,
+                            requires_grad=False)
+        qt_C = torch.tensor([[17, 18, 24, 47, 99, 99, 99, 99],
+                             [18, 21, 26, 66, 99, 99, 99, 99],
+                             [24, 26, 56, 99, 99, 99, 99, 99],
+                             [47, 66, 99, 99, 99, 99, 99, 99],
+                             [99, 99, 99, 99, 99, 99, 99, 99],
+                             [99, 99, 99, 99, 99, 99, 99, 99],
+                             [99, 99, 99, 99, 99, 99, 99, 99],
+                             [99, 99, 99, 99, 99, 99, 99, 99]],
+                            dtype=torch.double,
+                            device=device,
+                            requires_grad=False)
+        """
+        Reasoning behind this perceptual weight matrix: JPEG gives us a model of the frequencies that are important
+        for human perception. In that model, lower frequencies are more important than higher frequencies. Because
+        of this, the higher frequencies are the first to go during compression. As compression increases, the effect
+        spreads to the lower frequencies, which degrades perceptual quality. But when the lower frequencies are
+        preserved, JPEG does an excellent job of compression without a noticeable loss of quality.
+        In super resolution, we already have the low frequencies. In fact, that is really all we have in the low
+        resolution images.
+        As evidenced by the diff_means matrix above, what is lost in the SR images is the mid-range frequencies;
+        those across and towards the centre of the diagonal. We can bias our model to recover these frequencies
+        by having our loss function prioritize these coefficients, with priority determined by the magnitude of
+        relative change between the low-res and high-res images. But we can take this further, into a true
+        perceptual loss, by also prioritizing DCT coefficients by the importance the JPEG quantization table
+        assigns to them. That is how the weight table below is created.
+
+        The problem is that we don't know if the JPEG model is optimal. So there is room for qualitative evaluation
+        of the quantization table values. We can refine these perceptual weights by deleting each in turn for a small
+        set of images and evaluating the resulting change in perceived quality. I can do this on my own to start, and
+        if it works, I can do a small user study to determine this.
+        """
+        diff_means = torch.tensor(torch.load(dataset_diff_means_file), device=device)
+        perceptual_weights = torch.stack([(torch.ones_like(qt_Y, device=device) / qt_Y),
+                                          (torch.ones_like(qt_C, device=device) / qt_C),
+                                          (torch.ones_like(qt_C, device=device) / qt_C)])
+        perceptual_weights = perceptual_weights * diff_means
+        self.perceptual_weights = perceptual_weights / torch.mean(perceptual_weights)
+        super(FDPLLoss, self).__init__()
+
+    def forward(self, predictions, targets):
+        """
+        Args:
+            predictions (torch.tensor): output of an image transformation model.
+                shape: batch_size x 3 x H x W
+            targets (torch.tensor): ground truth images corresponding to outputs.
+                shape: batch_size x 3 x H x W
+        Returns:
+            loss (float): MSE between predicted and ground truth 2D DCT coefficients
+        """
+        # transition to fp64 and then convert to YCbCr color space.
+        predictions = rgb2ycbcr(predictions.double())
+        targets = rgb2ycbcr(targets.double())
+
+        # get DCT coefficients of ground truth patches
+        patches = extract_patches_2d(img=targets, patch_shape=(8, 8), batch_first=True)
+        ground_truth_dct = dct_2d(patches, norm='ortho')
+
+        # get DCT coefficients of transformed images
+        patches = extract_patches_2d(img=predictions, patch_shape=(8, 8), batch_first=True)
+        outputs_dct = dct_2d(patches, norm='ortho')
+        loss = torch.sum(((outputs_dct - ground_truth_dct).pow(2)) * self.perceptual_weights)
+        return loss
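
Two practical notes on FDPLLoss. The weights it builds are, per channel, the reciprocal JPEG quantization table multiplied element-wise by the dataset diff_means and then normalized to unit mean. And because the loss runs in FP64, the docstring advises keeping it outside amp.scale_loss(); a hedged sketch of that split (variable names are illustrative, not the exact SRGANModel code):

# Sketch only: keep the FP64 FDPL term outside apex's loss scaling.
with amp.scale_loss(l_g_pix + l_g_fea, optimizer_G) as scaled_loss:
    scaled_loss.backward(retain_graph=True)      # fp16-friendly losses go through amp
l_g_fdpl = cri_fdpl(fake_img, real_img) * fdpl_weight
l_g_fdpl.backward()                              # fp64 path, no amp scaling
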
codes/utils/colors.py (new file, 409 lines)
@@ -0,0 +1,409 @@
# Differentiable color conversions from https://github.com/TheZino/pytorch-color-conversions

from warnings import warn

import numpy as np
import torch
from scipy import linalg

xyz_from_rgb = torch.Tensor([[0.412453, 0.357580, 0.180423],
                             [0.212671, 0.715160, 0.072169],
                             [0.019334, 0.119193, 0.950227]])

rgb_from_xyz = torch.Tensor(linalg.inv(xyz_from_rgb))

illuminants = \
    {"A": {'2': torch.Tensor([(1.098466069456375, 1, 0.3558228003436005)]),
           '10': torch.Tensor([(1.111420406956693, 1, 0.3519978321919493)])},
     "D50": {'2': torch.Tensor([(0.9642119944211994, 1, 0.8251882845188288)]),
             '10': torch.Tensor([(0.9672062750333777, 1, 0.8142801513128616)])},
     "D55": {'2': torch.Tensor([(0.956797052643698, 1, 0.9214805860173273)]),
             '10': torch.Tensor([(0.9579665682254781, 1, 0.9092525159847462)])},
     "D65": {'2': torch.Tensor([(0.95047, 1., 1.08883)]),  # This was: `lab_ref_white`
             '10': torch.Tensor([(0.94809667673716, 1, 1.0730513595166162)])},
     "D75": {'2': torch.Tensor([(0.9497220898840717, 1, 1.226393520724154)]),
             '10': torch.Tensor([(0.9441713925645873, 1, 1.2064272211720228)])},
     "E": {'2': torch.Tensor([(1.0, 1.0, 1.0)]),
           '10': torch.Tensor([(1.0, 1.0, 1.0)])}}


# -------------------------------------------------------------
# The conversion functions that make use of the matrices above
# -------------------------------------------------------------


##### RGB - YCbCr

# Helper for the creation of module-global constant tensors
def _t(data):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # TODO inherit this
    device = torch.device("cpu")  # TODO inherit this
    return torch.tensor(data, requires_grad=False, dtype=torch.float32, device=device)

# Helper for color matrix multiplication
def _mul(coeffs, image):
    # This implementation is clearly suboptimal. The function will
    # be implemented with 'einsum' when a bug in pytorch 0.4.0 is
    # fixed (Einsum modifies variables in-place #7763).
    coeffs = coeffs.to(image.device)
    r0 = image[:, 0:1, :, :].repeat(1, 3, 1, 1) * coeffs[:, 0].view(1, 3, 1, 1)
    r1 = image[:, 1:2, :, :].repeat(1, 3, 1, 1) * coeffs[:, 1].view(1, 3, 1, 1)
    r2 = image[:, 2:3, :, :].repeat(1, 3, 1, 1) * coeffs[:, 2].view(1, 3, 1, 1)
    return r0 + r1 + r2
    # return torch.einsum("dc,bcij->bdij", (coeffs.to(image.device), image))

_RGB_TO_YCBCR = _t([[0.257, 0.504, 0.098], [-0.148, -0.291, 0.439], [0.439, -0.368, -0.071]])
_YCBCR_OFF = _t([0.063, 0.502, 0.502]).view(1, 3, 1, 1)


def rgb2ycbcr(rgb):
    """sRGB to YCbCr conversion."""
    clip_rgb = False
    if clip_rgb:
        rgb = torch.clamp(rgb, 0, 1)
    return _mul(_RGB_TO_YCBCR, rgb) + _YCBCR_OFF.to(rgb.device)


def ycbcr2rgb(ycbcr):
    """YCbCr to sRGB conversion."""
    clip_rgb = False
    rgb = _mul(torch.inverse(_RGB_TO_YCBCR), ycbcr - _YCBCR_OFF.to(ycbcr.device))
    if clip_rgb:
        rgb = torch.clamp(rgb, 0, 1)
    return rgb


##### HSV - RGB

def rgb2hsv(rgb):
    """
    R, G and B input range = 0 ÷ 1.0
    H, S and V output range = 0 ÷ 1.0
    """
    eps = 1e-7

    var_R = rgb[:, 0, :, :]
    var_G = rgb[:, 1, :, :]
    var_B = rgb[:, 2, :, :]

    var_Min = rgb.min(1)[0]      # Min. value of RGB
    var_Max = rgb.max(1)[0]      # Max. value of RGB
    del_Max = var_Max - var_Min  # Delta RGB value

    H = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device)
    S = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device)
    V = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device)

    V = var_Max

    # This is a gray, no chroma...
    mask = del_Max == 0
    H[mask] = 0
    S[mask] = 0

    # Chromatic data...
    S = del_Max / (var_Max + eps)

    del_R = (((var_Max - var_R) / 6) + (del_Max / 2)) / (del_Max + eps)
    del_G = (((var_Max - var_G) / 6) + (del_Max / 2)) / (del_Max + eps)
    del_B = (((var_Max - var_B) / 6) + (del_Max / 2)) / (del_Max + eps)

    H = torch.where(var_R == var_Max, del_B - del_G, H)
    H = torch.where(var_G == var_Max, (1 / 3) + del_R - del_B, H)
    H = torch.where(var_B == var_Max, (2 / 3) + del_G - del_R, H)

    # if ( H < 0 ) H += 1
    # if ( H > 1 ) H -= 1

    return torch.stack([H, S, V], 1)


def hsv2rgb(hsv):
    """
    H, S and V input range = 0 ÷ 1.0
    R, G and B output range = 0 ÷ 1.0
    """
    eps = 1e-7

    bb, cc, hh, ww = hsv.shape
    H = hsv[:, 0, :, :]
    S = hsv[:, 1, :, :]
    V = hsv[:, 2, :, :]

    # var_h = torch.zeros(bb,hh,ww)
    # var_s = torch.zeros(bb,hh,ww)
    # var_v = torch.zeros(bb,hh,ww)

    # var_r = torch.zeros(bb,hh,ww)
    # var_g = torch.zeros(bb,hh,ww)
    # var_b = torch.zeros(bb,hh,ww)

    # Grayscale
    if (S == 0).all():
        R = V
        G = V
        B = V

    # Chromatic data
    else:
        var_h = H * 6

        var_h[var_h == 6] = 0  # H must be < 1
        var_i = var_h.floor()  # Or ... var_i = floor( var_h )
        var_1 = V * (1 - S)
        var_2 = V * (1 - S * (var_h - var_i))
        var_3 = V * (1 - S * (1 - (var_h - var_i)))

        # else { var_r = V ; var_g = var_1 ; var_b = var_2 }
        var_r = V
        var_g = var_1
        var_b = var_2

        # var_i == 0 { var_r = V ; var_g = var_3 ; var_b = var_1 }
        var_r = torch.where(var_i == 0, V, var_r)
        var_g = torch.where(var_i == 0, var_3, var_g)
        var_b = torch.where(var_i == 0, var_1, var_b)

        # else if ( var_i == 1 ) { var_r = var_2 ; var_g = V ; var_b = var_1 }
        var_r = torch.where(var_i == 1, var_2, var_r)
        var_g = torch.where(var_i == 1, V, var_g)
        var_b = torch.where(var_i == 1, var_1, var_b)

        # else if ( var_i == 2 ) { var_r = var_1 ; var_g = V ; var_b = var_3 }
        var_r = torch.where(var_i == 2, var_1, var_r)
        var_g = torch.where(var_i == 2, V, var_g)
        var_b = torch.where(var_i == 2, var_3, var_b)

        # else if ( var_i == 3 ) { var_r = var_1 ; var_g = var_2 ; var_b = V }
        var_r = torch.where(var_i == 3, var_1, var_r)
        var_g = torch.where(var_i == 3, var_2, var_g)
        var_b = torch.where(var_i == 3, V, var_b)

        # else if ( var_i == 4 ) { var_r = var_3 ; var_g = var_1 ; var_b = V }
        var_r = torch.where(var_i == 4, var_3, var_r)
        var_g = torch.where(var_i == 4, var_1, var_g)
        var_b = torch.where(var_i == 4, V, var_b)

        R = var_r  # * 255
        G = var_g  # * 255
        B = var_b  # * 255

    return torch.stack([R, G, B], 1)


##### LAB - RGB

def _convert(matrix, arr):
    """Do the color space conversion.

    Parameters
    ----------
    matrix : array_like
        The 3x3 matrix to use.
    arr : array_like
        The input array.

    Returns
    -------
    out : ndarray, dtype=float
        The converted array.
    """
    if arr.is_cuda:
        matrix = matrix.cuda()

    bs, ch, h, w = arr.shape

    arr = arr.permute((0, 2, 3, 1))
    arr = arr.contiguous().view(-1, 1, 3)

    matrix = matrix.transpose(0, 1).unsqueeze(0)
    matrix = matrix.repeat(arr.shape[0], 1, 1)

    res = torch.bmm(arr, matrix)

    res = res.view(bs, h, w, ch)
    res = res.transpose(3, 2).transpose(2, 1)

    return res


def get_xyz_coords(illuminant, observer):
    """Get the XYZ coordinates of the given illuminant and observer [1]_.

    Parameters
    ----------
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}, optional
        The aperture angle of the observer.

    Returns
    -------
    (x, y, z) : tuple
        A tuple with 3 elements containing the XYZ coordinates of the given
        illuminant.

    Raises
    ------
    ValueError
        If either the illuminant or the observer angle are not supported or
        unknown.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Standard_illuminant
    """
    illuminant = illuminant.upper()
    try:
        return illuminants[illuminant][observer]
    except KeyError:
        raise ValueError("Unknown illuminant/observer combination\
        (\'{0}\', \'{1}\')".format(illuminant, observer))


def rgb2xyz(rgb):
    mask = rgb > 0.04045
    rgbm = rgb.clone()
    tmp = torch.pow((rgb + 0.055) / 1.055, 2.4)
    rgb = torch.where(mask, tmp, rgb)

    rgbm = rgb.clone()
    rgb[~mask] = rgbm[~mask] / 12.92
    return _convert(xyz_from_rgb, rgb)


def xyz2lab(xyz, illuminant="D65", observer="2"):
    # arr = _prepare_colorarray(xyz)
    xyz_ref_white = get_xyz_coords(illuminant, observer)
    # cuda
    if xyz.is_cuda:
        xyz_ref_white = xyz_ref_white.cuda()

    # scale by CIE XYZ tristimulus values of the reference white point
    xyz = xyz / xyz_ref_white.view(1, 3, 1, 1)
    # Nonlinear distortion and linear transformation
    mask = xyz > 0.008856
    xyzm = xyz.clone()
    xyz[mask] = torch.pow(xyzm[mask], 1 / 3)
    xyzm = xyz.clone()
    xyz[~mask] = 7.787 * xyzm[~mask] + 16. / 116.
    x, y, z = xyz[:, 0, :, :], xyz[:, 1, :, :], xyz[:, 2, :, :]
    # Vector scaling
    L = (116. * y) - 16.
    a = 500.0 * (x - y)
    b = 200.0 * (y - z)
    return torch.stack((L, a, b), 1)


def rgb2lab(rgb, illuminant="D65", observer="2"):
    """RGB to lab color space conversion.

    Parameters
    ----------
    rgb : array_like
        The image in RGB format, in a 3- or 4-D array of shape
        ``(.., ..,[ ..,] 3)``.
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}, optional
        The aperture angle of the observer.

    Returns
    -------
    out : ndarray
        The image in Lab format, in a 3- or 4-D array of shape
        ``(.., ..,[ ..,] 3)``.

    Raises
    ------
    ValueError
        If `rgb` is not a 3- or 4-D array of shape ``(.., ..,[ ..,] 3)``.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Standard_illuminant

    Notes
    -----
    This function uses rgb2xyz and xyz2lab.
    By default Observer= 2A, Illuminant= D65. CIE XYZ tristimulus values
    x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for
    a list of supported illuminants.
    """
    return xyz2lab(rgb2xyz(rgb), illuminant, observer)


def lab2xyz(lab, illuminant="D65", observer="2"):
    arr = lab.clone()
    L, a, b = arr[:, 0, :, :], arr[:, 1, :, :], arr[:, 2, :, :]
    y = (L + 16.) / 116.
    x = (a / 500.) + y
    z = y - (b / 200.)

    # if (z < 0).sum() > 0:
    #     warn('Color data out of range: Z < 0 in %s pixels' % (z < 0).sum().item())
    #     z[z < 0] = 0  # NO GRADIENT!!!!

    out = torch.stack((x, y, z), 1)

    mask = out > 0.2068966
    outm = out.clone()
    out[mask] = torch.pow(outm[mask], 3.)
    outm = out.clone()
    out[~mask] = (outm[~mask] - 16.0 / 116.) / 7.787

    # rescale to the reference white (illuminant)
    xyz_ref_white = get_xyz_coords(illuminant, observer)
    # cuda
    if lab.is_cuda:
        xyz_ref_white = xyz_ref_white.cuda()
    xyz_ref_white = xyz_ref_white.unsqueeze(2).unsqueeze(2).repeat(1, 1, out.shape[2], out.shape[3])
    out = out * xyz_ref_white
    return out


def xyz2rgb(xyz):
    arr = _convert(rgb_from_xyz, xyz)
    mask = arr > 0.0031308
    arrm = arr.clone()
    arr[mask] = 1.055 * torch.pow(arrm[mask], 1 / 2.4) - 0.055
    arrm = arr.clone()
    arr[~mask] = arrm[~mask] * 12.92

    # CLAMP KILLS GRADIENTS
    # mask_z = arr < 0
    # arr[mask_z] = 0
    # mask_o = arr > 1
    # arr[mask_o] = 1

    # torch.clamp(arr, 0, 1, out=arr)
    return arr


def lab2rgb(lab, illuminant="D65", observer="2"):
    """Lab to RGB color space conversion.

    Parameters
    ----------
    lab : array_like
        The image in Lab format, in a 3-D array of shape ``(.., .., 3)``.
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}, optional
        The aperture angle of the observer.

    Returns
    -------
    out : ndarray
        The image in RGB format, in a 3-D array of shape ``(.., .., 3)``.

    Raises
    ------
    ValueError
        If `lab` is not a 3-D array of shape ``(.., .., 3)``.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Standard_illuminant

    Notes
    -----
    This function uses lab2xyz and xyz2rgb.
    By default Observer= 2A, Illuminant= D65. CIE XYZ tristimulus values
    x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for
    a list of supported illuminants.
    """
    return xyz2rgb(lab2xyz(lab, illuminant, observer))
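
A quick sanity check one might run against these conversion helpers (not part of the commit): the YCbCr pair inverts through the matrix inverse, and rgb2lab/lab2rgb round-trips through XYZ:

import torch
from utils.colors import rgb2ycbcr, ycbcr2rgb, rgb2lab, lab2rgb

x = torch.rand(2, 3, 16, 16)
print((ycbcr2rgb(rgb2ycbcr(x)) - x).abs().max())   # roughly 1e-6, limited by float32
print((lab2rgb(rgb2lab(x)) - x).abs().max())       # small; the piecewise sRGB/Lab curves invert each other
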
codes/utils/fdpl_util.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import numpy as np
import torch
import torch.nn as nn

# note: all dct related functions are either exactly as or based on those
# at https://github.com/zh217/torch-dct
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    Vc = torch.rfft(v, 1, onesided=False)

    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V


def idct(X, norm=None):
    """
    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
    Our definition of idct is that idct(dct(x)) == x
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last dimension
    """
    x_shape = X.shape
    N = x_shape[-1]

    X_v = X.contiguous().view(-1, x_shape[-1]) / 2

    if norm == 'ortho':
        X_v[:, 0] *= np.sqrt(N) * 2
        X_v[:, 1:] *= np.sqrt(N / 2) * 2

    k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V_t_r = X_v
    V_t_r = V_t_r.to(X.device)
    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
    V_t_i = V_t_i.to(X.device)

    V_r = V_t_r * W_r - V_t_i * W_i
    V_i = V_t_r * W_i + V_t_i * W_r

    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)

    v = torch.irfft(V, 1, onesided=False)
    x = v.new_zeros(v.shape)
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip([1])[:, :N // 2]

    return x.view(*x_shape)


def dct_2d(x, norm=None):
    """
    2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 2 dimensions
    """
    X1 = dct(x, norm=norm)
    X2 = dct(X1.transpose(-1, -2), norm=norm)
    return X2.transpose(-1, -2)


def idct_2d(X, norm=None):
    """
    The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
    Our definition of idct is that idct_2d(dct_2d(x)) == x
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 2 dimensions
    """
    x1 = idct(X, norm=norm)
    x2 = idct(x1.transpose(-1, -2), norm=norm)
    return x2.transpose(-1, -2)


def extract_patches_2d(img, patch_shape, step=[1.0, 1.0], batch_first=False):
    """
    source: https://gist.github.com/dem123456789/23f18fd78ac8da9615c347905e64fc78
    """
    patch_H, patch_W = patch_shape[0], patch_shape[1]
    if img.size(2) < patch_H:
        num_padded_H_Top = (patch_H - img.size(2)) // 2
        num_padded_H_Bottom = patch_H - img.size(2) - num_padded_H_Top
        padding_H = nn.ConstantPad2d((0, 0, num_padded_H_Top, num_padded_H_Bottom), 0)
        img = padding_H(img)
    if img.size(3) < patch_W:
        num_padded_W_Left = (patch_W - img.size(3)) // 2
        num_padded_W_Right = patch_W - img.size(3) - num_padded_W_Left
        padding_W = nn.ConstantPad2d((num_padded_W_Left, num_padded_W_Right, 0, 0), 0)
        img = padding_W(img)
    step_int = [0, 0]
    step_int[0] = int(patch_H * step[0]) if isinstance(step[0], float) else step[0]
    step_int[1] = int(patch_W * step[1]) if isinstance(step[1], float) else step[1]
    patches_fold_H = img.unfold(2, patch_H, step_int[0])
    if (img.size(2) - patch_H) % step_int[0] != 0:
        patches_fold_H = torch.cat((patches_fold_H,
                                    img[:, :, -patch_H:, :].permute(0, 1, 3, 2).unsqueeze(2)), dim=2)
    patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_int[1])
    if (img.size(3) - patch_W) % step_int[1] != 0:
        patches_fold_HW = torch.cat((patches_fold_HW,
                                     patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)), dim=3)
    patches = patches_fold_HW.permute(2, 3, 0, 1, 4, 5)
    patches = patches.reshape(-1, img.size(0), img.size(1), patch_H, patch_W)
    if batch_first:
        patches = patches.permute(1, 0, 2, 3, 4)
    return patches
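
A hedged usage check for the two helpers FDPLLoss relies on; the patch shape follows the b, p, c, w, h unpacking used in compute_fdpl_perceptual_weights.py, and torch.rfft/irfft tie this to the pre-1.8 PyTorch the rest of the repo already assumes:

import torch
from utils.fdpl_util import dct_2d, idct_2d, extract_patches_2d

imgs = torch.rand(4, 3, 32, 32, dtype=torch.double)
patches = extract_patches_2d(img=imgs, patch_shape=(8, 8), batch_first=True)
print(patches.shape)  # (4, 16, 3, 8, 8): batch x patches x channels x 8 x 8
coeffs = dct_2d(patches, norm='ortho')
print((idct_2d(coeffs, norm='ortho') - patches).abs().max())  # ~0: idct_2d inverts dct_2d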