# Differentiable color conversions from https://github.com/TheZino/pytorch-color-conversions from warnings import warn import numpy as np import torch from scipy import linalg xyz_from_rgb = torch.Tensor([[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169], [0.019334, 0.119193, 0.950227]]) rgb_from_xyz = torch.Tensor(linalg.inv(xyz_from_rgb)) illuminants = \ {"A": {'2': torch.Tensor([(1.098466069456375, 1, 0.3558228003436005)]), '10': torch.Tensor([(1.111420406956693, 1, 0.3519978321919493)])}, "D50": {'2': torch.Tensor([(0.9642119944211994, 1, 0.8251882845188288)]), '10': torch.Tensor([(0.9672062750333777, 1, 0.8142801513128616)])}, "D55": {'2': torch.Tensor([(0.956797052643698, 1, 0.9214805860173273)]), '10': torch.Tensor([(0.9579665682254781, 1, 0.9092525159847462)])}, "D65": {'2': torch.Tensor([(0.95047, 1., 1.08883)]), # This was: `lab_ref_white` '10': torch.Tensor([(0.94809667673716, 1, 1.0730513595166162)])}, "D75": {'2': torch.Tensor([(0.9497220898840717, 1, 1.226393520724154)]), '10': torch.Tensor([(0.9441713925645873, 1, 1.2064272211720228)])}, "E": {'2': torch.Tensor([(1.0, 1.0, 1.0)]), '10': torch.Tensor([(1.0, 1.0, 1.0)])}} # ------------------------------------------------------------- # The conversion functions that make use of the matrices above # ------------------------------------------------------------- ##### RGB - YCbCr # Helper for the creation of module-global constant tensors def _t(data): # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # TODO inherit this device = torch.device("cpu") # TODO inherit this return torch.tensor(data, requires_grad=False, dtype=torch.float32, device=device) # Helper for color matrix multiplication def _mul(coeffs, image): # This is implementation is clearly suboptimal. The function will # be implemented with 'einsum' when a bug in pytorch 0.4.0 will be # fixed (Einsum modifies variables in-place #7763). coeffs = coeffs.to(image.device) r0 = image[:, 0:1, :, :].repeat(1, 3, 1, 1) * coeffs[:, 0].view(1, 3, 1, 1) r1 = image[:, 1:2, :, :].repeat(1, 3, 1, 1) * coeffs[:, 1].view(1, 3, 1, 1) r2 = image[:, 2:3, :, :].repeat(1, 3, 1, 1) * coeffs[:, 2].view(1, 3, 1, 1) return r0 + r1 + r2 # return torch.einsum("dc,bcij->bdij", (coeffs.to(image.device), image)) _RGB_TO_YCBCR = _t([[0.257, 0.504, 0.098], [-0.148, -0.291, 0.439], [0.439 , -0.368, -0.071]]) _YCBCR_OFF = _t([0.063, 0.502, 0.502]).view(1, 3, 1, 1) def rgb2ycbcr(rgb): """sRGB to YCbCr conversion.""" clip_rgb=False if clip_rgb: rgb = torch.clamp(rgb, 0, 1) return _mul(_RGB_TO_YCBCR, rgb) + _YCBCR_OFF.to(rgb.device) def ycbcr2rgb(rgb): """YCbCr to sRGB conversion.""" clip_rgb=False rgb = _mul(torch.inverse(_RGB_TO_YCBCR), rgb - _YCBCR_OFF.to(rgb.device)) if clip_rgb: rgb = torch.clamp(rgb, 0, 1) return rgb ##### HSV - RGB def rgb2hsv(rgb): """ R, G and B input range = 0 ÷ 1.0 H, S and V output range = 0 ÷ 1.0 """ eps = 1e-7 var_R = rgb[:,0,:,:] var_G = rgb[:,1,:,:] var_B = rgb[:,2,:,:] var_Min = rgb.min(1)[0] #Min. value of RGB var_Max = rgb.max(1)[0] #Max. value of RGB del_Max = var_Max - var_Min ##Delta RGB value H = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device) S = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device) V = torch.zeros([rgb.shape[0], rgb.shape[2], rgb.shape[3]]).to(rgb.device) V = var_Max #This is a gray, no chroma... mask = del_Max == 0 H[mask] = 0 S[mask] = 0 #Chromatic data... S = del_Max / (var_Max + eps) del_R = ( ( ( var_Max - var_R ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps) del_G = ( ( ( var_Max - var_G ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps) del_B = ( ( ( var_Max - var_B ) / 6 ) + ( del_Max / 2 ) ) / (del_Max + eps) H = torch.where( var_R == var_Max , del_B - del_G, H) H = torch.where( var_G == var_Max , ( 1 / 3 ) + del_R - del_B, H) H = torch.where( var_B == var_Max ,( 2 / 3 ) + del_G - del_R, H) # if ( H < 0 ) H += 1 # if ( H > 1 ) H -= 1 return torch.stack([H, S, V], 1) def hsv2rgb(hsv): """ H, S and V input range = 0 ÷ 1.0 R, G and B output range = 0 ÷ 1.0 """ eps = 1e-7 bb,cc,hh,ww = hsv.shape H = hsv[:,0,:,:] S = hsv[:,1,:,:] V = hsv[:,2,:,:] # var_h = torch.zeros(bb,hh,ww) # var_s = torch.zeros(bb,hh,ww) # var_v = torch.zeros(bb,hh,ww) # var_r = torch.zeros(bb,hh,ww) # var_g = torch.zeros(bb,hh,ww) # var_b = torch.zeros(bb,hh,ww) # Grayscale if (S == 0).all(): R = V G = V B = V # Chromatic data else: var_h = H * 6 var_h[var_h == 6] = 0 #H must be < 1 var_i = var_h.floor() #Or ... var_i = floor( var_h ) var_1 = V * ( 1 - S ) var_2 = V * ( 1 - S * ( var_h - var_i ) ) var_3 = V * ( 1 - S * ( 1 - ( var_h - var_i ) ) ) # else { var_r = V ; var_g = var_1 ; var_b = var_2 } var_r = V var_g = var_1 var_b = var_2 # var_i == 0 { var_r = V ; var_g = var_3 ; var_b = var_1 } var_r = torch.where(var_i == 0, V, var_r) var_g = torch.where(var_i == 0, var_3, var_g) var_b = torch.where(var_i == 0, var_1, var_b) # else if ( var_i == 1 ) { var_r = var_2 ; var_g = V ; var_b = var_1 } var_r = torch.where(var_i == 1, var_2, var_r) var_g = torch.where(var_i == 1, V, var_g) var_b = torch.where(var_i == 1, var_1, var_b) # else if ( var_i == 2 ) { var_r = var_1 ; var_g = V ; var_b = var_3 } var_r = torch.where(var_i == 2, var_1, var_r) var_g = torch.where(var_i == 2, V, var_g) var_b = torch.where(var_i == 2, var_3, var_b) # else if ( var_i == 3 ) { var_r = var_1 ; var_g = var_2 ; var_b = V } var_r = torch.where(var_i == 3, var_1, var_r) var_g = torch.where(var_i == 3, var_2, var_g) var_b = torch.where(var_i == 3, V, var_b) # else if ( var_i == 4 ) { var_r = var_3 ; var_g = var_1 ; var_b = V } var_r = torch.where(var_i == 4, var_3, var_r) var_g = torch.where(var_i == 4, var_1, var_g) var_b = torch.where(var_i == 4, V, var_b) R = var_r #* 255 G = var_g #* 255 B = var_b #* 255 return torch.stack([R, G, B], 1) ##### LAB - RGB def _convert(matrix, arr): """Do the color space conversion. Parameters ---------- matrix : array_like The 3x3 matrix to use. arr : array_like The input array. Returns ------- out : ndarray, dtype=float The converted array. """ if arr.is_cuda: matrix = matrix.cuda() bs, ch, h, w = arr.shape arr = arr.permute((0,2,3,1)) arr = arr.contiguous().view(-1,1,3) matrix = matrix.transpose(0,1).unsqueeze(0) matrix = matrix.repeat(arr.shape[0],1,1) res = torch.bmm(arr,matrix) res = res.view(bs,h,w,ch) res = res.transpose(3,2).transpose(2,1) return res def get_xyz_coords(illuminant, observer): """Get the XYZ coordinates of the given illuminant and observer [1]_. Parameters ---------- illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional The name of the illuminant (the function is NOT case sensitive). observer : {"2", "10"}, optional The aperture angle of the observer. Returns ------- (x, y, z) : tuple A tuple with 3 elements containing the XYZ coordinates of the given illuminant. Raises ------ ValueError If either the illuminant or the observer angle are not supported or unknown. References ---------- .. [1] https://en.wikipedia.org/wiki/Standard_illuminant """ illuminant = illuminant.upper() try: return illuminants[illuminant][observer] except KeyError: raise ValueError("Unknown illuminant/observer combination\ (\'{0}\', \'{1}\')".format(illuminant, observer)) def rgb2xyz(rgb): mask = rgb > 0.04045 rgbm = rgb.clone() tmp = torch.pow((rgb + 0.055) / 1.055, 2.4) rgb = torch.where(mask, tmp, rgb) rgbm = rgb.clone() rgb[~mask] = rgbm[~mask]/12.92 return _convert(xyz_from_rgb, rgb) def xyz2lab(xyz, illuminant="D65", observer="2"): # arr = _prepare_colorarray(xyz) xyz_ref_white = get_xyz_coords(illuminant, observer) #cuda if xyz.is_cuda: xyz_ref_white = xyz_ref_white.cuda() # scale by CIE XYZ tristimulus values of the reference white point xyz = xyz / xyz_ref_white.view(1,3,1,1) # Nonlinear distortion and linear transformation mask = xyz > 0.008856 xyzm = xyz.clone() xyz[mask] = torch.pow(xyzm[mask], 1/3) xyzm = xyz.clone() xyz[~mask] = 7.787 * xyzm[~mask] + 16. / 116. x, y, z = xyz[:, 0, :, :], xyz[:, 1, :, :], xyz[:, 2, :, :] # Vector scaling L = (116. * y) - 16. a = 500.0 * (x - y) b = 200.0 * (y - z) return torch.stack((L,a,b), 1) def rgb2lab(rgb, illuminant="D65", observer="2"): """RGB to lab color space conversion. Parameters ---------- rgb : array_like The image in RGB format, in a 3- or 4-D array of shape ``(.., ..,[ ..,] 3)``. illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional The name of the illuminant (the function is NOT case sensitive). observer : {"2", "10"}, optional The aperture angle of the observer. Returns ------- out : ndarray The image in Lab format, in a 3- or 4-D array of shape ``(.., ..,[ ..,] 3)``. Raises ------ ValueError If `rgb` is not a 3- or 4-D array of shape ``(.., ..,[ ..,] 3)``. References ---------- .. [1] https://en.wikipedia.org/wiki/Standard_illuminant Notes ----- This function uses rgb2xyz and xyz2lab. By default Observer= 2A, Illuminant= D65. CIE XYZ tristimulus values x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for a list of supported illuminants. """ return xyz2lab(rgb2xyz(rgb), illuminant, observer) def lab2xyz(lab, illuminant="D65", observer="2"): arr = lab.clone() L, a, b = arr[:, 0, :, :], arr[:, 1, :, :], arr[:, 2, :, :] y = (L + 16.) / 116. x = (a / 500.) + y z = y - (b / 200.) # if (z < 0).sum() > 0: # warn('Color data out of range: Z < 0 in %s pixels' % (z < 0).sum().item()) # z[z < 0] = 0 # NO GRADIENT!!!! out = torch.stack((x, y, z),1) mask = out > 0.2068966 outm = out.clone() out[mask] = torch.pow(outm[mask], 3.) outm = out.clone() out[~mask] = (outm[~mask] - 16.0 / 116.) / 7.787 # rescale to the reference white (illuminant) xyz_ref_white = get_xyz_coords(illuminant, observer) # cuda if lab.is_cuda: xyz_ref_white = xyz_ref_white.cuda() xyz_ref_white = xyz_ref_white.unsqueeze(2).unsqueeze(2).repeat(1,1,out.shape[2],out.shape[3]) out = out * xyz_ref_white return out def xyz2rgb(xyz): arr = _convert(rgb_from_xyz, xyz) mask = arr > 0.0031308 arrm = arr.clone() arr[mask] = 1.055 * torch.pow(arrm[mask], 1 / 2.4) - 0.055 arrm = arr.clone() arr[~mask] = arrm[~mask] * 12.92 # CLAMP KILLS GRADIENTS # mask_z = arr < 0 # arr[mask_z] = 0 # mask_o = arr > 1 # arr[mask_o] = 1 # torch.clamp(arr, 0, 1, out=arr) return arr def lab2rgb(lab, illuminant="D65", observer="2"): """Lab to RGB color space conversion. Parameters ---------- lab : array_like The image in Lab format, in a 3-D array of shape ``(.., .., 3)``. illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional The name of the illuminant (the function is NOT case sensitive). observer : {"2", "10"}, optional The aperture angle of the observer. Returns ------- out : ndarray The image in RGB format, in a 3-D array of shape ``(.., .., 3)``. Raises ------ ValueError If `lab` is not a 3-D array of shape ``(.., .., 3)``. References ---------- .. [1] https://en.wikipedia.org/wiki/Standard_illuminant Notes ----- This function uses lab2xyz and xyz2rgb. By default Observer= 2A, Illuminant= D65. CIE XYZ tristimulus values x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for a list of supported illuminants. """ return xyz2rgb(lab2xyz(lab, illuminant, observer))