Pass a corruption factor through the dataset into the upsampling network

The intuition is that this will help guide the network to make better-informed decisions
about how it performs upsampling, based on how it perceives the underlying content.

(I'm giving up on letting networks detect their own quality - I'm not convinced it is
actually feasible)
This commit is contained in:
James Betker 2021-06-07 09:13:54 -06:00
parent 4dd053f694
commit 6c6e82406e
3 changed files with 60 additions and 40 deletions

View File

@ -1,10 +1,29 @@
import random
from math import cos, pi
import cv2
import numpy as np
from data.util import read_img
from PIL import Image
from io import BytesIO
# Feeds a uniform random sample through a cosine curve to slightly bias corruptions towards "uncorrupted".
# The return value is on [0,1] with a bias towards 0.
def get_rand():
    """Draw a corruption strength on [0,1], skewed towards 0.

    A uniform sample is mapped through ``1 - cos(u * pi / 2)``; the curve is flat
    near ``u = 0``, so small outputs are more likely than large ones.
    """
    return 1 - cos(random.random() * pi / 2)
# Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data)
'''
if __name__ == '__main__':
import numpy as np
import matplotlib.pyplot as plt
data = np.asarray([get_rand() for _ in range(5000)])
plt.plot(data, np.random.uniform(size=(5000,)), 'x')
plt.show()
'''
# Performs image corruption on a list of images from a configurable set of corruption
# options.
class ImageCorruptor:
@ -16,7 +35,7 @@ class ImageCorruptor:
return
self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
def corrupt_images(self, imgs):
def corrupt_images(self, imgs, return_entropy=False):
if self.num_corrupts == 0 and not self.fixed_corruptions:
return imgs
@ -24,51 +43,45 @@ class ImageCorruptor:
augmentations = []
else:
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
# Source of entropy, which should be used across all images.
rand_int_f = random.randint(1, 999999)
rand_int_a = random.randint(1, 999999)
# Sources of entropy
corrupted_imgs = []
entropy = []
applied_augs = augmentations + self.fixed_corruptions
for img in imgs:
for aug in augmentations:
img = self.apply_corruption(img, aug, rand_int_a, applied_augs)
r = get_rand()
img = self.apply_corruption(img, aug, r, applied_augs)
for aug in self.fixed_corruptions:
img = self.apply_corruption(img, aug, rand_int_f, applied_augs)
r = get_rand()
img = self.apply_corruption(img, aug, r, applied_augs)
entropy.append(r)
corrupted_imgs.append(img)
return corrupted_imgs
if return_entropy:
return corrupted_imgs, entropy
else:
return corrupted_imgs
def apply_corruption(self, img, aug, rand_int, applied_augmentations):
def apply_corruption(self, img, aug, rand_val, applied_augmentations):
if 'color_quantization' in aug:
# Color quantization
quant_div = 2 ** ((rand_int % 3) + 2)
quant_div = 2 ** (int(rand_val * 10 / 3) + 2)
img = img * 255
img = (img // quant_div) * quant_div
img = img / 255
elif 'gaussian_blur' in aug:
# Gaussian Blur
if aug == 'gaussian_blur_3':
kernel = 3
elif aug == 'gaussian_blur_5':
kernel = 5
else:
kernel = 2 * self.blur_scale * (rand_int % 3) + 1
img = cv2.GaussianBlur(img, (kernel, kernel), 3)
img = cv2.GaussianBlur(img, (0,0), rand_val*1.5)
elif 'motion_blur' in aug:
# Motion blur
intensity = self.blur_scale * (rand_int % 3) + 1
angle = (rand_int // 3) % 360
intensity = self.blur_scale * rand_val * 3 + 1
angle = random.randint(0,360)
k = np.zeros((intensity, intensity), dtype=np.float32)
k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0),
(intensity, intensity))
k = k * (1.0 / np.sum(k))
img = cv2.filter2D(img, -1, k)
elif 'smooth_blur' in aug:
# Smooth blur
kernel = 2 * self.blur_scale * (rand_int % 3) + 1
img = cv2.blur(img, ksize=(kernel, kernel))
elif 'block_noise' in aug:
# Large distortion blocks in part of an img, such as is used to mask out a face.
pass
@ -78,7 +91,7 @@ class ImageCorruptor:
if 'lq_resampling4x' == aug:
scale = 4
interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
mode = rand_int % len(interpolation_modes)
mode = random.randint(0,4) % len(interpolation_modes)
# Downsample first, then upsample using the random mode.
img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode)
@ -96,7 +109,7 @@ class ImageCorruptor:
if 'noise-5' == aug:
noise_intensity = 5 / 255.0
else:
noise_intensity = (rand_int % 4 + 2) / 255.0 # Between 1-4
noise_intensity = (rand_val*4 + 2) / 255.0
img += np.random.randn(*img.shape) * noise_intensity
elif 'jpeg' in aug:
if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations:
@ -118,7 +131,7 @@ class ImageCorruptor:
else:
raise NotImplementedError("specified jpeg corruption doesn't exist")
# JPEG compression
qf = (rand_int % range + lo)
qf = (int((1-rand_val)*range) + lo)
# Use PIL to perform a mock compression to a data buffer, then swap back to cv2.
img = (img * 255).astype(np.uint8)
img = Image.fromarray(img)
@ -129,7 +142,7 @@ class ImageCorruptor:
img = read_img("buffer", jpeg_img_bytes, rgb=True)
elif 'saturation' in aug:
# Lightening / saturation
saturation = float(rand_int % 10) * .03
saturation = rand_val * .3
img = np.clip(img + saturation, a_max=1, a_min=0)
elif 'none' not in aug:
raise NotImplementedError("Augmentation doesn't exist")

View File

@ -112,14 +112,14 @@ class ImageFolderDataset:
local_scale = local_scale // special_factor
else:
hs = [h.copy() for h in hs]
hs = self.corruptor.corrupt_images(hs)
hs, ent = self.corruptor.corrupt_images(hs, return_entropy=True)
for hq in hs:
h, w, _ = hq.shape
ls.append(cv2.resize(hq, (h // local_scale, w // local_scale), interpolation=cv2.INTER_AREA))
# Corrupt the LQ image (only in eval mode)
if not self.corrupt_before_downsize:
ls = self.corruptor.corrupt_images(ls)
return ls
ls, ent = self.corruptor.corrupt_images(ls, return_entropy=True)
return ls, ent
def __len__(self):
return self.len
@ -184,9 +184,10 @@ class ImageFolderDataset:
out_dict['alt_hq'] = alt_hq
if not self.skip_lq:
lqs = self.synthesize_lq(for_lq)
lqs, ent = self.synthesize_lq(for_lq)
ls = lqs[0]
out_dict['lq'] = torch.from_numpy(np.ascontiguousarray(np.transpose(ls, (2, 0, 1)))).float()
out_dict['corruption_entropy'] = torch.tensor(ent)
if len(lqs) > 1:
alt_lq = lqs[1]
out_dict['alt_lq'] = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_lq, (2, 0, 1)))).float()
@ -215,25 +216,25 @@ class ImageFolderDataset:
if __name__ == '__main__':
opt = {
'name': 'amalgam',
'paths': ['E:\\4k6k\\datasets\\ns_images\\256_unsupervised'],
'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
'weights': [1],
'target_size': 256,
'force_multiple': 1,
'scale': 2,
'corrupt_before_downsize': True,
'fetch_alt_image': True,
'fetch_alt_image': False,
'disable_flip': True,
'fixed_corruptions': [ 'jpeg-broad' ],
'fixed_corruptions': [ 'jpeg-broad', 'gaussian_blur' ],
'num_corrupts_per_image': 0,
'corruption_blur_scale': 0
}
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=2)
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0)
import os
output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
os.makedirs(output_path, exist_ok=True)
for i, d in tqdm(enumerate(ds)):
lq = d['lq']
torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
#torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
if i >= 200000:
break

View File

@ -660,12 +660,18 @@ class SuperResModel(UNetModel):
Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""
def __init__(self, image_size, in_channels, *args, **kwargs):
super().__init__(image_size, in_channels * 2, *args, **kwargs)
def __init__(self, image_size, in_channels, num_corruptions=0, *args, **kwargs):
    """
    Build the super-resolution UNet with room for the conditioning channels.

    :param image_size: forwarded unchanged to the parent constructor.
    :param in_channels: channel count of the input `x`; the parent is built for twice this
                        (the input plus the upsampled `low_res` image) plus `num_corruptions`.
    :param num_corruptions: number of corruption-factor planes concatenated onto the
                            conditioning input in `forward`.
    """
    # Store the real count (previously hard-coded to 0): forward() sizes its all-zero fallback
    # planes from this attribute, so a stale 0 would produce fewer input channels than the
    # parent network was constructed to accept whenever no corruption_factor is supplied.
    self.num_corruptions = num_corruptions
    super().__init__(image_size, in_channels * 2 + num_corruptions, *args, **kwargs)
def forward(self, x, timesteps, low_res=None, **kwargs):
_, _, new_height, new_width = x.shape
def forward(self, x, timesteps, low_res=None, corruption_factor=None, **kwargs):
    """
    Run the parent UNet on `x` with the low-res image and corruption planes appended.

    :param x: 4-D input batch; its last two dims set the conditioning resolution.
    :param timesteps: passed straight through to the parent forward.
    :param low_res: low-resolution image, bilinearly resized to the spatial size of `x`.
    :param corruption_factor: optional per-sample values, tiled into full-size planes; when
                              omitted, `self.num_corruptions` all-zero planes are used instead.
    :return: the parent forward's output.
    """
    batch, _, out_h, out_w = x.shape
    resized = F.interpolate(low_res, (out_h, out_w), mode="bilinear")
    if corruption_factor is None:
        planes = torch.zeros((batch, self.num_corruptions, out_h, out_w), dtype=torch.float, device=x.device)
    else:
        # One value per (sample, corruption), repeated over every spatial position.
        planes = corruption_factor.view(batch, -1, 1, 1).repeat(1, 1, out_h, out_w)
    conditioning = torch.cat([resized, planes], dim=1)
    return super().forward(th.cat([x, conditioning], dim=1), timesteps, **kwargs)