DL-Art-School/codes/data/images/image_corruptor.py
2022-03-16 12:05:56 -06:00

210 lines
8.0 KiB
Python

import functools
import random
from math import cos, pi
import cv2
import kornia
import numpy as np
import torch
from kornia.augmentation import ColorJitter
from data.util import read_img
from PIL import Image
from io import BytesIO
# The disabled snippet below gives a rough visualization of the distribution returned by
# ImageCorruptor.get_rand(). (Y-axis is meaningless, just spreads data)
from utils.util import opt_get
'''
if __name__ == '__main__':
import numpy as np
import matplotlib.pyplot as plt
data = np.asarray([get_rand() for _ in range(5000)])
plt.plot(data, np.random.uniform(size=(5000,)), 'x')
plt.show()
'''
def kornia_color_jitter_numpy(img, setting):
    """Run Kornia's ColorJitter over a numpy image, if the setting is big enough.

    Args:
        img: Numpy image array; it is permuted with (2,0,1) before being handed
            to Kornia, i.e. it is treated as h,w,c.
        setting: Single strength value passed as all four ColorJitter arguments.

    Returns:
        The jittered image as a numpy array in the input's axis order, or
        ``img`` itself (untouched) when ``setting * 255`` does not exceed 1.
    """
    # Too small a setting to bother with: hand the input straight back.
    if not (setting * 255 > 1):
        return img

    # Kornia's ColorJitter requires pytorch arrays in b,c,h,w format.
    batched = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    jitter = ColorJitter(setting, setting, setting, setting)
    jittered = jitter(batched)
    # Drop the batch axis and go back to the original axis order.
    return jittered.squeeze(0).permute(1, 2, 0).numpy()
# Performs image corruption on a list of images from a configurable set of corruption
# options.
class ImageCorruptor:
    """Applies a configurable set of corruptions to a list of images.

    Options read from ``opt``:
      - ``corruption_blur_scale``: multiplier on blur strength (default 1).
      - ``fixed_corruptions``: corruption names applied to every image (default []).
      - ``num_corrupts_per_image``: number of random corruptions drawn per
        ``corrupt_images`` call (default 0).
      - ``random_corruptions``: pool the random corruptions are drawn from (default []).
      - ``cosine_bias``: bias random strengths towards 0 (default True).
      - ``random_seed``: optional seed for this object's private RNG.

    All randomness (augmentation choice, strengths, blur angle, resampling mode
    and noise) is derived from the private RNG, so a seeded corruptor is
    reproducible after ``reset_random()``.
    """

    # JPEG variants -> (lowest quality factor, width of the quality range).
    _JPEG_QUALITY_RANGES = {
        'jpeg': (10, 20),
        'jpeg-low': (15, 10),
        'jpeg-medium': (23, 25),
        'jpeg-broad': (15, 60),
        'jpeg-normal': (47, 35),
    }

    def __init__(self, opt):
        """Read the corruption configuration from ``opt`` and seed the RNG."""
        self.opt = opt
        self.reset_random()
        self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
        self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
        self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
        self.cosine_bias = opt_get(opt, ['cosine_bias'], True)
        # Defined unconditionally so the attribute exists even when no random
        # corruptions are requested.
        self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []

    def reset_random(self):
        """(Re)create the private RNG, seeded from ``opt['random_seed']`` when present."""
        if 'random_seed' in self.opt.keys():
            self.rand = random.Random(self.opt['random_seed'])
        else:
            self.rand = random.Random()

    # Feeds a random uniform through a cosine distribution to slightly bias corruptions towards "uncorrupted".
    # Return is on [0,1] with a bias towards 0.
    def get_rand(self):
        """Return a corruption strength in [0,1], cosine-biased towards 0 if enabled."""
        r = self.rand.random()
        if self.cosine_bias:
            return 1 - cos(r * pi / 2)
        return r

    def corrupt_images(self, imgs, return_entropy=False):
        """Corrupt every image in ``imgs``.

        One set of random corruptions is drawn per call and shared by all
        images; each image then gets its own random strengths. Fixed
        corruptions are applied after the random ones.

        Args:
            imgs: List of images to corrupt.
            return_entropy: When True, also return the strengths used for the
                fixed corruptions (one entry per image per fixed corruption).

        Returns:
            The list of corrupted images, or ``(images, entropy)`` when
            ``return_entropy`` is True. With nothing configured, ``imgs`` is
            returned as-is.
        """
        if self.num_corrupts == 0 and not self.fixed_corruptions:
            return (imgs, []) if return_entropy else imgs

        if self.num_corrupts == 0:
            augmentations = []
        else:
            # Drawn from the private RNG so that `random_seed` also pins the selection.
            augmentations = self.rand.choices(self.random_corruptions, k=self.num_corrupts)

        corrupted_imgs = []
        # Sources of entropy
        entropy = []
        applied_augs = list(augmentations) + list(self.fixed_corruptions)
        for img in imgs:
            # Undo functions belong to a single image; keeping one list across
            # images would re-apply earlier images' undos to later ones.
            undo_fns = []
            for aug in augmentations:
                r = self.get_rand()
                img, undo_fn = self.apply_corruption(img, aug, r, applied_augs)
                if undo_fn is not None:
                    undo_fns.append(undo_fn)
            for aug in self.fixed_corruptions:
                r = self.get_rand()
                img, undo_fn = self.apply_corruption(img, aug, r, applied_augs)
                entropy.append(r)
                if undo_fn is not None:
                    undo_fns.append(undo_fn)
            # Apply undo_fns after all corruptions are finished, in same order.
            for ufn in undo_fns:
                img = ufn(img)
            corrupted_imgs.append(img)

        if return_entropy:
            return corrupted_imgs, entropy
        return corrupted_imgs

    @staticmethod
    def _lq_resampling_undo(scale, img):
        """Upsample ``img`` by ``scale`` with linear interpolation."""
        return cv2.resize(img, dsize=(img.shape[1] * scale, img.shape[0] * scale),
                          interpolation=cv2.INTER_LINEAR)

    def apply_corruption(self, img, aug, rand_val, applied_augmentations):
        """Apply the single corruption named ``aug`` to ``img``.

        Args:
            img: Image array to corrupt (the arithmetic here treats values as
                lying in [0,1]). The input array is not modified.
            aug: Corruption name; matched by substring in a fixed order.
            rand_val: Strength in [0,1] for this corruption.
            applied_augmentations: All corruption names being applied in this
                pass; JPEG compression is skipped when noise is among them.

        Returns:
            ``(img, undo_fn)`` where ``undo_fn`` is None or a callable to run on
            the image once all corruptions have been applied.

        Raises:
            NotImplementedError: For an unknown corruption or jpeg variant.
        """
        undo_fn = None
        if 'color_quantization' in aug:
            # Color quantization: snap 8-bit values to multiples of 4, 8, 16 or 32.
            quant_div = 2 ** (int(rand_val * 10 / 3) + 2)
            img = img * 255
            img = (img // quant_div) * quant_div
            img = img / 255
        elif 'color_jitter' in aug:
            lo_end = 0
            hi_end = .2
            setting = rand_val * (hi_end - lo_end) + lo_end
            img = kornia_color_jitter_numpy(img, setting)
        elif 'gaussian_blur' in aug:
            sigma = self.blur_scale * rand_val * 1.5
            # A zero sigma means "no blur"; with a (0,0) kernel size there is
            # nothing for OpenCV to derive the kernel from, so skip the call.
            if sigma > 0:
                img = cv2.GaussianBlur(img, (0, 0), sigma)
        elif 'motion_blur' in aug:
            # Motion blur. The kernel size is used as an array shape, so it must be an int.
            intensity = int(self.blur_scale * rand_val * 3) + 1
            angle = self.rand.randint(0, 360)
            k = np.zeros((intensity, intensity), dtype=np.float32)
            k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
            center = (intensity / 2 - 0.5, intensity / 2 - 0.5)
            k = cv2.warpAffine(k, cv2.getRotationMatrix2D(center, angle, 1.0),
                               (intensity, intensity))
            kernel_sum = np.sum(k)
            # Normalise so the filter preserves brightness; skip a degenerate all-zero kernel.
            if kernel_sum > 0:
                img = cv2.filter2D(img, -1, k * (1.0 / kernel_sum))
        elif 'block_noise' in aug:
            # Large distortion blocks in part of an img, such as is used to mask out a face.
            pass
        elif 'lq_resampling' in aug:
            # Random mode interpolation HR->LR->HR
            if 'lq_resampling4x' == aug:
                scale = 4
            elif rand_val < .3:
                scale = 1
            elif rand_val < .7:
                scale = 2
            else:
                scale = 4
            if scale > 1:
                interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
                # Pick one of the listed flags (not a raw index, which would map
                # onto different OpenCV flag values).
                mode = self.rand.choice(interpolation_modes)
                # Downsample with the random mode now; the linear upsample is
                # deferred to the undo function so later corruptions act on the LR image.
                lr_size = (max(1, img.shape[1] // scale), max(1, img.shape[0] // scale))
                img = cv2.resize(img, dsize=lr_size, interpolation=mode)
                undo_fn = functools.partial(self._lq_resampling_undo, scale)
        elif 'color_shift' in aug:
            # Color shift
            pass
        elif 'interlacing' in aug:
            # Interlacing distortion
            pass
        elif 'chromatic_aberration' in aug:
            # Chromatic aberration
            pass
        elif 'noise' in aug:
            # Random noise
            if 'noise-5' == aug:
                noise_intensity = 5 / 255.0
            else:
                noise_intensity = (rand_val * 6) / 255.0
            # Seed numpy from the private RNG so noise is reproducible too.
            noise_rng = np.random.RandomState(self.rand.getrandbits(32))
            noise = noise_rng.rand(*img.shape) * noise_intensity
            # Build a new array (same dtype) instead of mutating the caller's image.
            img = img + noise.astype(img.dtype)
        elif 'jpeg' in aug:
            if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations:
                if aug not in self._JPEG_QUALITY_RANGES:
                    raise NotImplementedError("specified jpeg corruption doesn't exist")
                lo, span = self._JPEG_QUALITY_RANGES[aug]
                # JPEG compression: stronger corruption -> lower quality factor.
                qf = int((1 - rand_val) * span) + lo
                # Use PIL to perform a mock compression to a data buffer, then swap back to cv2.
                # Clip first so out-of-range values saturate instead of wrapping in uint8.
                img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
                img = Image.fromarray(img)
                buffer = BytesIO()
                img.save(buffer, "JPEG", quality=qf, optimize=True)
                buffer.seek(0)
                jpeg_img_bytes = np.asarray(bytearray(buffer.read()), dtype="uint8")
                img = read_img("buffer", jpeg_img_bytes, rgb=True)
        elif 'saturation' in aug:
            # Lightening / saturation
            saturation = rand_val * .3
            img = np.clip(img + saturation, a_max=1, a_min=0)
        elif 'greyscale' in aug:
            # Replace every channel with the per-pixel channel mean.
            img = np.tile(np.mean(img, axis=2, keepdims=True), [1, 1, 3])
        elif 'none' not in aug:
            raise NotImplementedError("Augmentation doesn't exist")

        return img, undo_fn