7629cb0e61
New loss type that can replace PSNR loss. Works against the frequency domain and focuses on frequency features loss during hr->lr conversion.
410 lines
12 KiB
Python
410 lines
12 KiB
Python
# Differentiable color conversions from https://github.com/TheZino/pytorch-color-conversions
|
|
|
|
from warnings import warn
|
|
|
|
import numpy as np
|
|
import torch
|
|
from scipy import linalg
|
|
|
|
# 3x3 matrix taking an RGB triple to CIE XYZ (one row per output coordinate).
xyz_from_rgb = torch.Tensor([
    [0.412453, 0.357580, 0.180423],
    [0.212671, 0.715160, 0.072169],
    [0.019334, 0.119193, 0.950227],
])

# Reverse mapping (XYZ -> RGB), obtained by numerically inverting the matrix above.
rgb_from_xyz = torch.Tensor(linalg.inv(xyz_from_rgb))
|
|
|
|
# Reference white points in XYZ, keyed first by illuminant name and then by
# observer angle ("2" or "10"). Every entry is a tensor of shape (1, 3).
illuminants = {
    "A": {
        '2': torch.Tensor([(1.098466069456375, 1, 0.3558228003436005)]),
        '10': torch.Tensor([(1.111420406956693, 1, 0.3519978321919493)]),
    },
    "D50": {
        '2': torch.Tensor([(0.9642119944211994, 1, 0.8251882845188288)]),
        '10': torch.Tensor([(0.9672062750333777, 1, 0.8142801513128616)]),
    },
    "D55": {
        '2': torch.Tensor([(0.956797052643698, 1, 0.9214805860173273)]),
        '10': torch.Tensor([(0.9579665682254781, 1, 0.9092525159847462)]),
    },
    "D65": {
        '2': torch.Tensor([(0.95047, 1., 1.08883)]),  # This was: `lab_ref_white`
        '10': torch.Tensor([(0.94809667673716, 1, 1.0730513595166162)]),
    },
    "D75": {
        '2': torch.Tensor([(0.9497220898840717, 1, 1.226393520724154)]),
        '10': torch.Tensor([(0.9441713925645873, 1, 1.2064272211720228)]),
    },
    "E": {
        '2': torch.Tensor([(1.0, 1.0, 1.0)]),
        '10': torch.Tensor([(1.0, 1.0, 1.0)]),
    },
}
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
# The conversion functions that make use of the matrices above
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
##### RGB - YCbCr
|
|
|
|
# Helper for the creation of module-global constant tensors
|
|
def _t(data):
|
|
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # TODO inherit this
|
|
device = torch.device("cpu") # TODO inherit this
|
|
return torch.tensor(data, requires_grad=False, dtype=torch.float32, device=device)
|
|
|
|
# Helper for color matrix multiplication
|
|
def _mul(coeffs, image):
|
|
# This is implementation is clearly suboptimal. The function will
|
|
# be implemented with 'einsum' when a bug in pytorch 0.4.0 will be
|
|
# fixed (Einsum modifies variables in-place #7763).
|
|
coeffs = coeffs.to(image.device)
|
|
r0 = image[:, 0:1, :, :].repeat(1, 3, 1, 1) * coeffs[:, 0].view(1, 3, 1, 1)
|
|
r1 = image[:, 1:2, :, :].repeat(1, 3, 1, 1) * coeffs[:, 1].view(1, 3, 1, 1)
|
|
r2 = image[:, 2:3, :, :].repeat(1, 3, 1, 1) * coeffs[:, 2].view(1, 3, 1, 1)
|
|
return r0 + r1 + r2
|
|
# return torch.einsum("dc,bcij->bdij", (coeffs.to(image.device), image))
|
|
|
|
# RGB -> YCbCr matrix; rows yield Y, Cb and Cr before the offset is added.
_RGB_TO_YCBCR = _t([
    [0.257, 0.504, 0.098],
    [-0.148, -0.291, 0.439],
    [0.439, -0.368, -0.071],
])
# Per-channel offset, reshaped to (1, 3, 1, 1) so it broadcasts over 4-D images.
_YCBCR_OFF = _t([0.063, 0.502, 0.502]).view(1, 3, 1, 1)
|
|
|
|
|
|
def rgb2ycbcr(rgb, *, clip_rgb=False):
    """sRGB to YCbCr conversion.

    Parameters
    ----------
    rgb : torch.Tensor
        Image indexed as ``(batch, 3, height, width)``.
    clip_rgb : bool, optional
        When true, clamp the input to ``[0, 1]`` before converting. Defaults
        to ``False``, which matches the previous hard-coded behaviour.

    Returns
    -------
    torch.Tensor
        ``_RGB_TO_YCBCR`` applied per pixel plus ``_YCBCR_OFF``, on the same
        device as ``rgb``.
    """
    if clip_rgb:
        rgb = torch.clamp(rgb, 0, 1)
    return _mul(_RGB_TO_YCBCR, rgb) + _YCBCR_OFF.to(rgb.device)
|
|
|
|
|
|
def ycbcr2rgb(rgb, *, clip_rgb=False):
    """YCbCr to sRGB conversion.

    Parameters
    ----------
    rgb : torch.Tensor
        Despite its name this is the YCbCr image to convert, indexed as
        ``(batch, 3, height, width)``.
    clip_rgb : bool, optional
        When true, clamp the converted result to ``[0, 1]``. Defaults to
        ``False``, which matches the previous hard-coded behaviour.

    Returns
    -------
    torch.Tensor
        The image after removing ``_YCBCR_OFF`` and applying the inverse of
        ``_RGB_TO_YCBCR`` per pixel.
    """
    ycbcr_to_rgb = torch.inverse(_RGB_TO_YCBCR)
    rgb = _mul(ycbcr_to_rgb, rgb - _YCBCR_OFF.to(rgb.device))
    if clip_rgb:
        rgb = torch.clamp(rgb, 0, 1)
    return rgb
|
|
|
|
|
|
##### HSV - RGB
|
|
|
|
def rgb2hsv(rgb):
    """Convert a ``(batch, 3, height, width)`` RGB tensor to HSV.

    R, G and B input range = 0 ÷ 1.0
    H, S and V output range = 0 ÷ 1.0

    Grey pixels (max == min) get hue 0 and saturation 0. Negative raw hues
    are wrapped by adding 1 so the result stays inside the documented range.
    Only out-of-place operations are used, so the result is differentiable.
    """
    eps = 1e-7  # guards the divisions for black / grey pixels

    red = rgb[:, 0, :, :]
    green = rgb[:, 1, :, :]
    blue = rgb[:, 2, :, :]

    c_max = rgb.max(1)[0]   # max. value of RGB
    c_min = rgb.min(1)[0]   # min. value of RGB
    chroma = c_max - c_min  # delta RGB value

    value = c_max
    saturation = chroma / (c_max + eps)

    # Normalised distance of each channel from the maximum.
    d_red = (((c_max - red) / 6) + (chroma / 2)) / (chroma + eps)
    d_green = (((c_max - green) / 6) + (chroma / 2)) / (chroma + eps)
    d_blue = (((c_max - blue) / 6) + (chroma / 2)) / (chroma + eps)

    # Pick the hue formula of whichever channel holds the maximum; on ties
    # the later assignment wins.
    hue = torch.zeros_like(chroma)
    hue = torch.where(red == c_max, d_blue - d_green, hue)
    hue = torch.where(green == c_max, (1 / 3) + d_red - d_blue, hue)
    hue = torch.where(blue == c_max, (2 / 3) + d_green - d_red, hue)

    # The red sector yields hues in [-1/6, 1/6]; wrap the negative half.
    hue = torch.where(hue < 0, hue + 1, hue)

    # Grey has no chroma: reset the hue last, otherwise the selections above
    # (all true when R == G == B) would leave it at 2/3. This also blocks the
    # ~1/eps gradients of the d_* terms for such pixels.
    hue = torch.where(chroma == 0, torch.zeros_like(hue), hue)

    return torch.stack([hue, saturation, value], 1)
|
|
|
|
def hsv2rgb(hsv):
    """Convert a ``(batch, 3, height, width)`` HSV tensor to RGB.

    H, S and V input range = 0 ÷ 1.0
    R, G and B output range = 0 ÷ 1.0
    """
    # Unpacking four dimensions rejects any input that is not 4-D.
    _batch, _chan, _rows, _cols = hsv.shape

    hue = hsv[:, 0, :, :]
    sat = hsv[:, 1, :, :]
    val = hsv[:, 2, :, :]

    # Grayscale: with zero saturation everywhere each channel equals V.
    if (sat == 0).all():
        return torch.stack([val, val, val], 1)

    # Chromatic data: split the hue circle into six sectors.
    scaled = hue * 6
    scaled[scaled == 6] = 0  # H must be < 1
    sector = scaled.floor()
    frac = scaled - sector

    low = val * (1 - sat)
    falling = val * (1 - sat * frac)
    rising = val * (1 - sat * (1 - frac))

    # Start from the fallback assignment (anything that is not sector 0..4),
    # then overwrite per sector with its (R, G, B) source triple.
    red, green, blue = val, low, falling
    per_sector = (
        (val, rising, low),    # sector 0
        (falling, val, low),   # sector 1
        (low, val, rising),    # sector 2
        (low, falling, val),   # sector 3
        (rising, low, val),    # sector 4
    )
    for index, (r_src, g_src, b_src) in enumerate(per_sector):
        hit = sector == index
        red = torch.where(hit, r_src, red)
        green = torch.where(hit, g_src, green)
        blue = torch.where(hit, b_src, blue)

    return torch.stack([red, green, blue], 1)
|
|
|
|
|
|
##### LAB - RGB
|
|
|
|
def _convert(matrix, arr):
|
|
"""Do the color space conversion.
|
|
Parameters
|
|
----------
|
|
matrix : array_like
|
|
The 3x3 matrix to use.
|
|
arr : array_like
|
|
The input array.
|
|
Returns
|
|
-------
|
|
out : ndarray, dtype=float
|
|
The converted array.
|
|
"""
|
|
|
|
if arr.is_cuda:
|
|
matrix = matrix.cuda()
|
|
|
|
bs, ch, h, w = arr.shape
|
|
|
|
arr = arr.permute((0,2,3,1))
|
|
arr = arr.contiguous().view(-1,1,3)
|
|
|
|
matrix = matrix.transpose(0,1).unsqueeze(0)
|
|
matrix = matrix.repeat(arr.shape[0],1,1)
|
|
|
|
res = torch.bmm(arr,matrix)
|
|
|
|
res = res.view(bs,h,w,ch)
|
|
res = res.transpose(3,2).transpose(2,1)
|
|
|
|
|
|
return res
|
|
|
|
def get_xyz_coords(illuminant, observer):
    """Get the XYZ coordinates of the given illuminant and observer [1]_.

    Parameters
    ----------
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}
        The aperture angle of the observer. It is passed through ``str()``,
        so the integers ``2`` and ``10`` are accepted as well.

    Returns
    -------
    torch.Tensor
        The entry of the module-level ``illuminants`` table holding the XYZ
        coordinates of the given illuminant.

    Raises
    ------
    ValueError
        If either the illuminant or the observer angle are not supported or
        unknown.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Standard_illuminant
    """
    illuminant = illuminant.upper()
    observer = str(observer)
    try:
        return illuminants[illuminant][observer]
    except KeyError:
        # Adjacent literals instead of a backslash continuation inside the
        # string, which used to splice the next line's text (and any
        # indentation) straight into the message.
        raise ValueError(
            "Unknown illuminant/observer combination "
            "('{0}', '{1}')".format(illuminant, observer)
        ) from None
|
|
|
|
|
|
|
|
|
|
|
|
def rgb2xyz(rgb):
    """Convert a ``(batch, 3, height, width)`` RGB tensor to XYZ.

    Values above 0.04045 go through the power-law branch
    ``((v + 0.055) / 1.055) ** 2.4``; all others are divided by 12.92. The
    result is then multiplied per pixel by ``xyz_from_rgb``. The input tensor
    is not modified.
    """
    mask = rgb > 0.04045

    # Evaluate the power only on values that actually take that branch.
    # Inputs below -0.055 would give pow() a negative base, i.e. NaN, and
    # although torch.where discards that value in the forward pass its
    # gradient (0 * NaN) would still poison the backward pass.
    gamma_in = torch.where(mask, rgb, torch.ones_like(rgb))
    gamma = torch.pow((gamma_in + 0.055) / 1.055, 2.4)

    linear = torch.where(mask, gamma, rgb / 12.92)
    return _convert(xyz_from_rgb, linear)
|
|
|
|
def xyz2lab(xyz, illuminant="D65", observer="2"):
    """Convert a ``(batch, 3, height, width)`` XYZ tensor to Lab.

    Parameters
    ----------
    xyz : torch.Tensor
        The image in XYZ, channels on dim 1.
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}, optional
        The aperture angle of the observer.

    Returns
    -------
    torch.Tensor
        The L, a and b planes stacked on dim 1.

    Raises
    ------
    ValueError
        If the illuminant/observer combination is unknown
        (raised by ``get_xyz_coords``).
    """
    # Follow the input's device rather than "the current CUDA device".
    xyz_ref_white = get_xyz_coords(illuminant, observer).to(xyz.device)

    # scale by CIE XYZ tristimulus values of the reference white point
    xyz = xyz / xyz_ref_white.view(1, 3, 1, 1)

    # Nonlinear distortion and linear transformation
    mask = xyz > 0.008856
    # The cube root only ever sees values from its own branch; zero or
    # negative bases would otherwise produce inf/NaN gradients.
    root_in = torch.where(mask, xyz, torch.ones_like(xyz))
    xyz = torch.where(mask, torch.pow(root_in, 1 / 3), 7.787 * xyz + 16. / 116.)

    x, y, z = xyz[:, 0, :, :], xyz[:, 1, :, :], xyz[:, 2, :, :]

    # Vector scaling
    L = (116. * y) - 16.
    a = 500.0 * (x - y)
    b = 200.0 * (y - z)
    return torch.stack((L, a, b), 1)
|
|
|
|
def rgb2lab(rgb, illuminant="D65", observer="2"):
    """RGB to Lab color space conversion.

    Chains ``rgb2xyz`` and ``xyz2lab``.

    Parameters
    ----------
    rgb : torch.Tensor
        The image in RGB format, indexed as ``(batch, 3, height, width)``.
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}, optional
        The aperture angle of the observer.

    Returns
    -------
    out : torch.Tensor
        The image in Lab format, with the L, a and b planes on dim 1.

    Raises
    ------
    ValueError
        If the illuminant/observer combination is not supported.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Standard_illuminant

    Notes
    -----
    The defaults are the 2-degree observer and illuminant D65. See function
    `get_xyz_coords` for a list of supported illuminants.
    """
    xyz = rgb2xyz(rgb)
    return xyz2lab(xyz, illuminant, observer)
|
|
|
|
|
|
|
|
|
|
|
|
def lab2xyz(lab, illuminant="D65", observer="2"):
    """Convert a ``(batch, 3, height, width)`` Lab tensor to XYZ.

    Parameters
    ----------
    lab : torch.Tensor
        The image in Lab, with the L, a and b planes on dim 1. It is not
        modified.
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}, optional
        The aperture angle of the observer.

    Returns
    -------
    torch.Tensor
        The X, Y and Z planes stacked on dim 1, rescaled by the reference
        white of the chosen illuminant.

    Raises
    ------
    ValueError
        If the illuminant/observer combination is unknown
        (raised by ``get_xyz_coords``).
    """
    L, a, b = lab[:, 0, :, :], lab[:, 1, :, :], lab[:, 2, :, :]
    y = (L + 16.) / 116.
    x = (a / 500.) + y
    z = y - (b / 200.)

    # Negative z is deliberately left alone: zeroing it in place would cut
    # the gradient.
    out = torch.stack((x, y, z), 1)

    # Cube above the threshold, linear segment below. Both branches are
    # finite for finite input, so a plain torch.where is safe here.
    mask = out > 0.2068966
    out = torch.where(mask, torch.pow(out, 3.), (out - 16.0 / 116.) / 7.787)

    # rescale to the reference white (illuminant), on the input's own device
    xyz_ref_white = get_xyz_coords(illuminant, observer).to(lab.device)
    return out * xyz_ref_white.view(1, 3, 1, 1)
|
|
|
|
def xyz2rgb(xyz):
    """Convert a ``(batch, 3, height, width)`` XYZ tensor to RGB.

    The image is multiplied per pixel by ``rgb_from_xyz``; values above
    0.0031308 then become ``1.055 * v ** (1 / 2.4) - 0.055`` and the rest are
    multiplied by 12.92. The result is NOT clamped to [0, 1].
    """
    rgb = _convert(rgb_from_xyz, xyz)

    above = rgb > 0.0031308
    below = ~above

    # A single snapshot serves both branches: the two masks are disjoint, so
    # the first write never touches what the second one reads.
    snapshot = rgb.clone()
    rgb[above] = 1.055 * torch.pow(snapshot[above], 1 / 2.4) - 0.055
    rgb[below] = snapshot[below] * 12.92

    # CLAMP KILLS GRADIENTS: no clipping of values below 0 or above 1 here.
    return rgb
|
|
|
|
def lab2rgb(lab, illuminant="D65", observer="2"):
    """Lab to RGB color space conversion.

    Chains ``lab2xyz`` and ``xyz2rgb``.

    Parameters
    ----------
    lab : torch.Tensor
        The image in Lab format, indexed as ``(batch, 3, height, width)``.
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}, optional
        The aperture angle of the observer.

    Returns
    -------
    out : torch.Tensor
        The image in RGB format, with the R, G and B planes on dim 1.

    Raises
    ------
    ValueError
        If the illuminant/observer combination is not supported.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Standard_illuminant

    Notes
    -----
    The defaults are the 2-degree observer and illuminant D65. See function
    `get_xyz_coords` for a list of supported illuminants.
    """
    xyz = lab2xyz(lab, illuminant, observer)
    return xyz2rgb(xyz)
|