DL-Art-School/codes/utils/colors.py
James Betker 7629cb0e61 Add FDPL Loss
New loss type that can replace PSNR loss. Works in the frequency domain
and focuses on the loss of frequency features during the HR->LR conversion.
2020-07-30 20:47:57 -06:00

410 lines
12 KiB
Python

# Differentiable color conversions from https://github.com/TheZino/pytorch-color-conversions
from warnings import warn
import numpy as np
import torch
from scipy import linalg
# 3x3 matrix taking linear sRGB to CIE XYZ; row d of the matrix produces
# output channel d (X, Y, Z) when applied to an RGB pixel via _convert().
xyz_from_rgb = torch.Tensor([[0.412453, 0.357580, 0.180423],
                             [0.212671, 0.715160, 0.072169],
                             [0.019334, 0.119193, 0.950227]])
# Its matrix inverse: CIE XYZ back to linear sRGB.
rgb_from_xyz = torch.Tensor(linalg.inv(xyz_from_rgb))
# Reference-white XYZ tristimulus values, keyed first by illuminant name and
# then by observer aperture angle ('2' or '10' degrees). Each entry is a
# (1, 3) tensor; looked up through get_xyz_coords().
illuminants = \
    {"A": {'2': torch.Tensor([(1.098466069456375, 1, 0.3558228003436005)]),
           '10': torch.Tensor([(1.111420406956693, 1, 0.3519978321919493)])},
     "D50": {'2': torch.Tensor([(0.9642119944211994, 1, 0.8251882845188288)]),
             '10': torch.Tensor([(0.9672062750333777, 1, 0.8142801513128616)])},
     "D55": {'2': torch.Tensor([(0.956797052643698, 1, 0.9214805860173273)]),
             '10': torch.Tensor([(0.9579665682254781, 1, 0.9092525159847462)])},
     "D65": {'2': torch.Tensor([(0.95047, 1., 1.08883)]),  # This was: `lab_ref_white`
             '10': torch.Tensor([(0.94809667673716, 1, 1.0730513595166162)])},
     "D75": {'2': torch.Tensor([(0.9497220898840717, 1, 1.226393520724154)]),
             '10': torch.Tensor([(0.9441713925645873, 1, 1.2064272211720228)])},
     "E": {'2': torch.Tensor([(1.0, 1.0, 1.0)]),
           '10': torch.Tensor([(1.0, 1.0, 1.0)])}}
# -------------------------------------------------------------
# The conversion functions that make use of the matrices above
# -------------------------------------------------------------
##### RGB - YCbCr
# Helper for the creation of module-global constant tensors
def _t(data):
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # TODO inherit this
device = torch.device("cpu") # TODO inherit this
return torch.tensor(data, requires_grad=False, dtype=torch.float32, device=device)
# Helper for color matrix multiplication
def _mul(coeffs, image):
# This is implementation is clearly suboptimal. The function will
# be implemented with 'einsum' when a bug in pytorch 0.4.0 will be
# fixed (Einsum modifies variables in-place #7763).
coeffs = coeffs.to(image.device)
r0 = image[:, 0:1, :, :].repeat(1, 3, 1, 1) * coeffs[:, 0].view(1, 3, 1, 1)
r1 = image[:, 1:2, :, :].repeat(1, 3, 1, 1) * coeffs[:, 1].view(1, 3, 1, 1)
r2 = image[:, 2:3, :, :].repeat(1, 3, 1, 1) * coeffs[:, 2].view(1, 3, 1, 1)
return r0 + r1 + r2
# return torch.einsum("dc,bcij->bdij", (coeffs.to(image.device), image))
# RGB -> YCbCr conversion matrix; the coefficients look like the BT.601
# "studio swing" convention (Y in [16/255, 235/255]) — TODO confirm against
# the upstream reference implementation.
_RGB_TO_YCBCR = _t([[0.257, 0.504, 0.098], [-0.148, -0.291, 0.439], [0.439 , -0.368, -0.071]])
# Per-channel offsets (Y, Cb, Cr), shaped (1, 3, 1, 1) so they broadcast
# over a (B, 3, H, W) batch.
_YCBCR_OFF = _t([0.063, 0.502, 0.502]).view(1, 3, 1, 1)
def rgb2ycbcr(rgb, clip_rgb=False):
    """sRGB to YCbCr conversion.

    Parameters
    ----------
    rgb : torch.Tensor
        (B, 3, H, W) RGB batch, values nominally in [0, 1].
    clip_rgb : bool, optional
        If True, clamp the input to [0, 1] first. Defaults to False — the
        historical behavior (clamping kills gradients for out-of-range
        values). Previously this was a hard-coded local; it is now a
        backward-compatible keyword argument.

    Returns
    -------
    torch.Tensor
        (B, 3, H, W) YCbCr batch.
    """
    if clip_rgb:
        rgb = torch.clamp(rgb, 0, 1)
    return _mul(_RGB_TO_YCBCR, rgb) + _YCBCR_OFF.to(rgb.device)
def ycbcr2rgb(rgb, clip_rgb=False):
    """YCbCr to sRGB conversion.

    Parameters
    ----------
    rgb : torch.Tensor
        (B, 3, H, W) YCbCr batch.
    clip_rgb : bool, optional
        If True, clamp the RGB result to [0, 1]. Defaults to False — the
        historical behavior (clamping kills gradients). Previously this
        was a hard-coded local; it is now a backward-compatible keyword
        argument.

    Returns
    -------
    torch.Tensor
        (B, 3, H, W) RGB batch.
    """
    # Invert the forward matrix on the fly; the matrix is 3x3, so the
    # per-call inverse is cheap.
    rgb = _mul(torch.inverse(_RGB_TO_YCBCR), rgb - _YCBCR_OFF.to(rgb.device))
    if clip_rgb:
        rgb = torch.clamp(rgb, 0, 1)
    return rgb
##### HSV - RGB
def rgb2hsv(rgb):
    """
    Differentiable RGB -> HSV conversion.

    R, G and B input range = 0 ÷ 1.0
    H, S and V output range = 0 ÷ 1.0

    Parameters
    ----------
    rgb : torch.Tensor
        (B, 3, H, W) RGB batch.

    Returns
    -------
    torch.Tensor
        (B, 3, H, W) batch stacked as [H, S, V].

    NOTE(review): preserved quirks of the original — H is never wrapped
    into [0, 1] (the classic "if H < 0: H += 1" step is commented out
    upstream), and for achromatic pixels all channels tie so the last
    torch.where leaves H at 2/3.
    """
    eps = 1e-7  # guards the divisions for black / achromatic pixels
    var_R = rgb[:, 0, :, :]
    var_G = rgb[:, 1, :, :]
    var_B = rgb[:, 2, :, :]
    var_Min = rgb.min(1)[0]      # per-pixel min over channels
    var_Max = rgb.max(1)[0]      # per-pixel max over channels (= V)
    del_Max = var_Max - var_Min  # chroma
    V = var_Max
    # S is 0 for achromatic pixels (del_Max == 0) by construction, so the
    # original's explicit masked zeroing was dead code and is removed.
    S = del_Max / (var_Max + eps)
    # Per-channel hue contributions.
    del_R = (((var_Max - var_R) / 6) + (del_Max / 2)) / (del_Max + eps)
    del_G = (((var_Max - var_G) / 6) + (del_Max / 2)) / (del_Max + eps)
    del_B = (((var_Max - var_B) / 6) + (del_Max / 2)) / (del_Max + eps)
    # Pick the hue sector by which channel holds the maximum; when
    # channels tie, later matches win (B, then G, then R priority).
    H = torch.zeros_like(V)
    H = torch.where(var_R == var_Max, del_B - del_G, H)
    H = torch.where(var_G == var_Max, (1 / 3) + del_R - del_B, H)
    H = torch.where(var_B == var_Max, (2 / 3) + del_G - del_R, H)
    return torch.stack([H, S, V], 1)
def hsv2rgb(hsv):
    """
    Differentiable HSV -> RGB conversion.

    H, S and V input range = 0 ÷ 1.0
    R, G and B output range = 0 ÷ 1.0

    Parameters
    ----------
    hsv : torch.Tensor
        (B, 3, H, W) batch stacked as [H, S, V].

    Returns
    -------
    torch.Tensor
        (B, 3, H, W) RGB batch.
    """
    # Removed the unused `eps` local and the unused shape unpack from the
    # original; the conversion logic itself is unchanged.
    H = hsv[:, 0, :, :]
    S = hsv[:, 1, :, :]
    V = hsv[:, 2, :, :]
    # Fully achromatic batch: every pixel is gray, each channel equals V.
    if (S == 0).all():
        R = V
        G = V
        B = V
    # Chromatic data: classic six-sector hue dial.
    else:
        var_h = H * 6
        var_h[var_h == 6] = 0   # H must be < 1, so sector 6 wraps to 0
        var_i = var_h.floor()   # hue sector index in {0..5}
        var_1 = V * (1 - S)
        var_2 = V * (1 - S * (var_h - var_i))
        var_3 = V * (1 - S * (1 - (var_h - var_i)))
        # Default assignment covers the final sector (var_i == 5):
        # { var_r = V ; var_g = var_1 ; var_b = var_2 }
        var_r = V
        var_g = var_1
        var_b = var_2
        # var_i == 0: { var_r = V ; var_g = var_3 ; var_b = var_1 }
        var_r = torch.where(var_i == 0, V, var_r)
        var_g = torch.where(var_i == 0, var_3, var_g)
        var_b = torch.where(var_i == 0, var_1, var_b)
        # var_i == 1: { var_r = var_2 ; var_g = V ; var_b = var_1 }
        var_r = torch.where(var_i == 1, var_2, var_r)
        var_g = torch.where(var_i == 1, V, var_g)
        var_b = torch.where(var_i == 1, var_1, var_b)
        # var_i == 2: { var_r = var_1 ; var_g = V ; var_b = var_3 }
        var_r = torch.where(var_i == 2, var_1, var_r)
        var_g = torch.where(var_i == 2, V, var_g)
        var_b = torch.where(var_i == 2, var_3, var_b)
        # var_i == 3: { var_r = var_1 ; var_g = var_2 ; var_b = V }
        var_r = torch.where(var_i == 3, var_1, var_r)
        var_g = torch.where(var_i == 3, var_2, var_g)
        var_b = torch.where(var_i == 3, V, var_b)
        # var_i == 4: { var_r = var_3 ; var_g = var_1 ; var_b = V }
        var_r = torch.where(var_i == 4, var_3, var_r)
        var_g = torch.where(var_i == 4, var_1, var_g)
        var_b = torch.where(var_i == 4, V, var_b)
        R = var_r
        G = var_g
        B = var_b
    return torch.stack([R, G, B], 1)
##### LAB - RGB
def _convert(matrix, arr):
"""Do the color space conversion.
Parameters
----------
matrix : array_like
The 3x3 matrix to use.
arr : array_like
The input array.
Returns
-------
out : ndarray, dtype=float
The converted array.
"""
if arr.is_cuda:
matrix = matrix.cuda()
bs, ch, h, w = arr.shape
arr = arr.permute((0,2,3,1))
arr = arr.contiguous().view(-1,1,3)
matrix = matrix.transpose(0,1).unsqueeze(0)
matrix = matrix.repeat(arr.shape[0],1,1)
res = torch.bmm(arr,matrix)
res = res.view(bs,h,w,ch)
res = res.transpose(3,2).transpose(2,1)
return res
def get_xyz_coords(illuminant, observer):
    """Get the XYZ coordinates of the given illuminant and observer [1]_.

    Parameters
    ----------
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}
        The aperture angle of the observer.

    Returns
    -------
    torch.Tensor
        A (1, 3) tensor with the XYZ coordinates of the given illuminant.

    Raises
    ------
    ValueError
        If either the illuminant or the observer angle are not supported
        or unknown.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Standard_illuminant
    """
    illuminant = illuminant.upper()
    try:
        return illuminants[illuminant][observer]
    except KeyError:
        # Implicit string concatenation (instead of the original backslash
        # line continuation) keeps stray source indentation out of the
        # error message.
        raise ValueError("Unknown illuminant/observer combination "
                         "('{0}', '{1}')".format(illuminant, observer))
def rgb2xyz(rgb):
    """sRGB to CIE XYZ conversion.

    Linearizes sRGB (inverse gamma) and applies the linear-RGB -> XYZ
    matrix.

    Parameters
    ----------
    rgb : torch.Tensor
        (B, 3, H, W) sRGB batch, values nominally in [0, 1].

    Returns
    -------
    torch.Tensor
        (B, 3, H, W) XYZ batch.
    """
    mask = rgb > 0.04045
    # Inverse sRGB gamma: power segment above the threshold, linear
    # segment (/12.92) below. The dead clone and in-place round-trip of
    # the original are folded into a single torch.where; values and
    # gradient paths are unchanged.
    gamma = torch.pow((rgb + 0.055) / 1.055, 2.4)
    linear = torch.where(mask, gamma, rgb / 12.92)
    return _convert(xyz_from_rgb, linear)
def xyz2lab(xyz, illuminant="D65", observer="2"):
    """Convert a (B, 3, H, W) batch of XYZ images to CIE Lab.

    The values are scaled by the reference white of the chosen
    illuminant/observer pair, passed through the CIE nonlinearity (cube
    root above the 0.008856 threshold, affine segment below), and then
    combined into the L, a, b channels.
    """
    white = get_xyz_coords(illuminant, observer)
    # cuda
    if xyz.is_cuda:
        white = white.cuda()
    # Scale by the CIE XYZ tristimulus values of the reference white point.
    scaled = xyz / white.view(1, 3, 1, 1)
    # Nonlinear distortion via masked indexing, so pow() only ever sees
    # the selected (above-threshold) entries.
    above = scaled > 0.008856
    snapshot = scaled.clone()
    scaled[above] = torch.pow(snapshot[above], 1 / 3)
    snapshot = scaled.clone()
    scaled[~above] = 7.787 * snapshot[~above] + 16. / 116.
    fx, fy, fz = scaled[:, 0, :, :], scaled[:, 1, :, :], scaled[:, 2, :, :]
    # Vector scaling into the Lab channels.
    L = (116. * fy) - 16.
    a = 500.0 * (fx - fy)
    b = 200.0 * (fy - fz)
    return torch.stack((L, a, b), 1)
def rgb2lab(rgb, illuminant="D65", observer="2"):
    """RGB to Lab color space conversion.

    Composes `rgb2xyz` and `xyz2lab`.

    Parameters
    ----------
    rgb : torch.Tensor
        The image batch in RGB format, shaped (B, 3, H, W).
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}, optional
        The aperture angle of the observer.

    Returns
    -------
    torch.Tensor
        The image batch in Lab format, shaped (B, 3, H, W).

    Notes
    -----
    By default Observer=2A, Illuminant=D65. CIE XYZ tristimulus values
    x_ref=95.047, y_ref=100., z_ref=108.883. See `get_xyz_coords` for a
    list of supported illuminants.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Standard_illuminant
    """
    xyz = rgb2xyz(rgb)
    return xyz2lab(xyz, illuminant, observer)
def lab2xyz(lab, illuminant="D65", observer="2"):
    """CIE Lab to XYZ conversion.

    Parameters
    ----------
    lab : torch.Tensor
        (B, 3, H, W) Lab batch.
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}, optional
        The aperture angle of the observer.

    Returns
    -------
    torch.Tensor
        (B, 3, H, W) XYZ batch.
    """
    # The original cloned the whole input tensor here, but L/a/b are only
    # ever read, so the copy was pure overhead and is removed.
    L, a, b = lab[:, 0, :, :], lab[:, 1, :, :], lab[:, 2, :, :]
    y = (L + 16.) / 116.
    x = (a / 500.) + y
    z = y - (b / 200.)
    # NOTE: z can be negative for out-of-gamut inputs; clamping it here
    # would kill gradients, so (as upstream chose) it is left unclamped.
    out = torch.stack((x, y, z), 1)
    # Invert the CIE nonlinearity: cube above the threshold, affine below.
    # Masked indexing keeps pow() off the unselected entries.
    mask = out > 0.2068966
    outm = out.clone()
    out[mask] = torch.pow(outm[mask], 3.)
    outm = out.clone()
    out[~mask] = (outm[~mask] - 16.0 / 116.) / 7.787
    # Rescale to the reference white (illuminant); the (1, 3) white point
    # broadcasts over the spatial dims, replacing the original's repeat().
    xyz_ref_white = get_xyz_coords(illuminant, observer)
    if lab.is_cuda:
        xyz_ref_white = xyz_ref_white.cuda()
    return out * xyz_ref_white.view(1, 3, 1, 1)
def xyz2rgb(xyz):
    """CIE XYZ to sRGB conversion.

    Applies the XYZ -> linear-RGB matrix and then the forward sRGB gamma
    (power segment above 0.0031308, linear segment below).
    """
    srgb = _convert(rgb_from_xyz, xyz)
    above = srgb > 0.0031308
    snapshot = srgb.clone()
    srgb[above] = 1.055 * torch.pow(snapshot[above], 1 / 2.4) - 0.055
    snapshot = srgb.clone()
    srgb[~above] = snapshot[~above] * 12.92
    # Deliberately NOT clamped to [0, 1]: clamping kills gradients
    # (upstream left the clamp commented out for exactly this reason).
    return srgb
def lab2rgb(lab, illuminant="D65", observer="2"):
    """Lab to RGB color space conversion.

    Composes `lab2xyz` and `xyz2rgb`.

    Parameters
    ----------
    lab : torch.Tensor
        The image batch in Lab format, shaped (B, 3, H, W).
    illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
    observer : {"2", "10"}, optional
        The aperture angle of the observer.

    Returns
    -------
    torch.Tensor
        The image batch in RGB format, shaped (B, 3, H, W).

    Notes
    -----
    By default Observer=2A, Illuminant=D65. CIE XYZ tristimulus values
    x_ref=95.047, y_ref=100., z_ref=108.883. See `get_xyz_coords` for a
    list of supported illuminants.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Standard_illuminant
    """
    xyz = lab2xyz(lab, illuminant, observer)
    return xyz2rgb(xyz)