Pass a corruption factor through the dataset into the upsampling network
The intuition is this will help guide the network to make better informed decisions about how it performs upsampling based on how it perceives the underlying content. (I'm giving up on letting networks detect their own quality - I'm not convinced it is actually feasible)
This commit is contained in:
parent
4dd053f694
commit
6c6e82406e
|
@ -1,10 +1,29 @@
|
|||
import random
|
||||
from math import cos, pi
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from data.util import read_img
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
# Feeds a random uniform through a cosine distribution to slightly bias corruptions towards "uncorrupted".
|
||||
# Return is on [0,1] with a bias towards 0.
|
||||
def get_rand():
|
||||
r = random.random()
|
||||
return 1 - cos(r * pi / 2)
|
||||
|
||||
# Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data)
|
||||
'''
|
||||
if __name__ == '__main__':
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
data = np.asarray([get_rand() for _ in range(5000)])
|
||||
plt.plot(data, np.random.uniform(size=(5000,)), 'x')
|
||||
plt.show()
|
||||
'''
|
||||
|
||||
# Performs image corruption on a list of images from a configurable set of corruption
|
||||
# options.
|
||||
class ImageCorruptor:
|
||||
|
@ -16,7 +35,7 @@ class ImageCorruptor:
|
|||
return
|
||||
self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
|
||||
|
||||
def corrupt_images(self, imgs):
|
||||
def corrupt_images(self, imgs, return_entropy=False):
|
||||
if self.num_corrupts == 0 and not self.fixed_corruptions:
|
||||
return imgs
|
||||
|
||||
|
@ -24,51 +43,45 @@ class ImageCorruptor:
|
|||
augmentations = []
|
||||
else:
|
||||
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
|
||||
# Source of entropy, which should be used across all images.
|
||||
rand_int_f = random.randint(1, 999999)
|
||||
rand_int_a = random.randint(1, 999999)
|
||||
|
||||
# Sources of entropy
|
||||
corrupted_imgs = []
|
||||
entropy = []
|
||||
applied_augs = augmentations + self.fixed_corruptions
|
||||
for img in imgs:
|
||||
for aug in augmentations:
|
||||
img = self.apply_corruption(img, aug, rand_int_a, applied_augs)
|
||||
r = get_rand()
|
||||
img = self.apply_corruption(img, aug, r, applied_augs)
|
||||
for aug in self.fixed_corruptions:
|
||||
img = self.apply_corruption(img, aug, rand_int_f, applied_augs)
|
||||
r = get_rand()
|
||||
img = self.apply_corruption(img, aug, r, applied_augs)
|
||||
entropy.append(r)
|
||||
corrupted_imgs.append(img)
|
||||
|
||||
return corrupted_imgs
|
||||
if return_entropy:
|
||||
return corrupted_imgs, entropy
|
||||
else:
|
||||
return corrupted_imgs
|
||||
|
||||
def apply_corruption(self, img, aug, rand_int, applied_augmentations):
|
||||
def apply_corruption(self, img, aug, rand_val, applied_augmentations):
|
||||
if 'color_quantization' in aug:
|
||||
# Color quantization
|
||||
quant_div = 2 ** ((rand_int % 3) + 2)
|
||||
quant_div = 2 ** (int(rand_val * 10 / 3) + 2)
|
||||
img = img * 255
|
||||
img = (img // quant_div) * quant_div
|
||||
img = img / 255
|
||||
elif 'gaussian_blur' in aug:
|
||||
# Gaussian Blur
|
||||
if aug == 'gaussian_blur_3':
|
||||
kernel = 3
|
||||
elif aug == 'gaussian_blur_5':
|
||||
kernel = 5
|
||||
else:
|
||||
kernel = 2 * self.blur_scale * (rand_int % 3) + 1
|
||||
img = cv2.GaussianBlur(img, (kernel, kernel), 3)
|
||||
img = cv2.GaussianBlur(img, (0,0), rand_val*1.5)
|
||||
elif 'motion_blur' in aug:
|
||||
# Motion blur
|
||||
intensity = self.blur_scale * (rand_int % 3) + 1
|
||||
angle = (rand_int // 3) % 360
|
||||
intensity = self.blur_scale * rand_val * 3 + 1
|
||||
angle = random.randint(0,360)
|
||||
k = np.zeros((intensity, intensity), dtype=np.float32)
|
||||
k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
|
||||
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0),
|
||||
(intensity, intensity))
|
||||
k = k * (1.0 / np.sum(k))
|
||||
img = cv2.filter2D(img, -1, k)
|
||||
elif 'smooth_blur' in aug:
|
||||
# Smooth blur
|
||||
kernel = 2 * self.blur_scale * (rand_int % 3) + 1
|
||||
img = cv2.blur(img, ksize=(kernel, kernel))
|
||||
elif 'block_noise' in aug:
|
||||
# Large distortion blocks in part of an img, such as is used to mask out a face.
|
||||
pass
|
||||
|
@ -78,7 +91,7 @@ class ImageCorruptor:
|
|||
if 'lq_resampling4x' == aug:
|
||||
scale = 4
|
||||
interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
|
||||
mode = rand_int % len(interpolation_modes)
|
||||
mode = random.randint(0,4) % len(interpolation_modes)
|
||||
# Downsample first, then upsample using the random mode.
|
||||
img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
|
||||
img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode)
|
||||
|
@ -96,7 +109,7 @@ class ImageCorruptor:
|
|||
if 'noise-5' == aug:
|
||||
noise_intensity = 5 / 255.0
|
||||
else:
|
||||
noise_intensity = (rand_int % 4 + 2) / 255.0 # Between 1-4
|
||||
noise_intensity = (rand_val*4 + 2) / 255.0
|
||||
img += np.random.randn(*img.shape) * noise_intensity
|
||||
elif 'jpeg' in aug:
|
||||
if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations:
|
||||
|
@ -118,7 +131,7 @@ class ImageCorruptor:
|
|||
else:
|
||||
raise NotImplementedError("specified jpeg corruption doesn't exist")
|
||||
# JPEG compression
|
||||
qf = (rand_int % range + lo)
|
||||
qf = (int((1-rand_val)*range) + lo)
|
||||
# Use PIL to perform a mock compression to a data buffer, then swap back to cv2.
|
||||
img = (img * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img)
|
||||
|
@ -129,7 +142,7 @@ class ImageCorruptor:
|
|||
img = read_img("buffer", jpeg_img_bytes, rgb=True)
|
||||
elif 'saturation' in aug:
|
||||
# Lightening / saturation
|
||||
saturation = float(rand_int % 10) * .03
|
||||
saturation = rand_val * .3
|
||||
img = np.clip(img + saturation, a_max=1, a_min=0)
|
||||
elif 'none' not in aug:
|
||||
raise NotImplementedError("Augmentation doesn't exist")
|
||||
|
|
|
@ -112,14 +112,14 @@ class ImageFolderDataset:
|
|||
local_scale = local_scale // special_factor
|
||||
else:
|
||||
hs = [h.copy() for h in hs]
|
||||
hs = self.corruptor.corrupt_images(hs)
|
||||
hs, ent = self.corruptor.corrupt_images(hs, return_entropy=True)
|
||||
for hq in hs:
|
||||
h, w, _ = hq.shape
|
||||
ls.append(cv2.resize(hq, (h // local_scale, w // local_scale), interpolation=cv2.INTER_AREA))
|
||||
# Corrupt the LQ image (only in eval mode)
|
||||
if not self.corrupt_before_downsize:
|
||||
ls = self.corruptor.corrupt_images(ls)
|
||||
return ls
|
||||
ls, ent = self.corruptor.corrupt_images(ls, return_entropy=True)
|
||||
return ls, ent
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
@ -184,9 +184,10 @@ class ImageFolderDataset:
|
|||
out_dict['alt_hq'] = alt_hq
|
||||
|
||||
if not self.skip_lq:
|
||||
lqs = self.synthesize_lq(for_lq)
|
||||
lqs, ent = self.synthesize_lq(for_lq)
|
||||
ls = lqs[0]
|
||||
out_dict['lq'] = torch.from_numpy(np.ascontiguousarray(np.transpose(ls, (2, 0, 1)))).float()
|
||||
out_dict['corruption_entropy'] = torch.tensor(ent)
|
||||
if len(lqs) > 1:
|
||||
alt_lq = lqs[1]
|
||||
out_dict['alt_lq'] = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_lq, (2, 0, 1)))).float()
|
||||
|
@ -215,25 +216,25 @@ class ImageFolderDataset:
|
|||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'name': 'amalgam',
|
||||
'paths': ['E:\\4k6k\\datasets\\ns_images\\256_unsupervised'],
|
||||
'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||
'weights': [1],
|
||||
'target_size': 256,
|
||||
'force_multiple': 1,
|
||||
'scale': 2,
|
||||
'corrupt_before_downsize': True,
|
||||
'fetch_alt_image': True,
|
||||
'fetch_alt_image': False,
|
||||
'disable_flip': True,
|
||||
'fixed_corruptions': [ 'jpeg-broad' ],
|
||||
'fixed_corruptions': [ 'jpeg-broad', 'gaussian_blur' ],
|
||||
'num_corrupts_per_image': 0,
|
||||
'corruption_blur_scale': 0
|
||||
}
|
||||
|
||||
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=2)
|
||||
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0)
|
||||
import os
|
||||
output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
for i, d in tqdm(enumerate(ds)):
|
||||
lq = d['lq']
|
||||
torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
|
||||
#torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
|
||||
if i >= 200000:
|
||||
break
|
|
@ -660,12 +660,18 @@ class SuperResModel(UNetModel):
|
|||
Expects an extra kwarg `low_res` to condition on a low-resolution image.
|
||||
"""
|
||||
|
||||
def __init__(self, image_size, in_channels, *args, **kwargs):
|
||||
super().__init__(image_size, in_channels * 2, *args, **kwargs)
|
||||
def __init__(self, image_size, in_channels, num_corruptions=0, *args, **kwargs):
|
||||
self.num_corruptions = 0
|
||||
super().__init__(image_size, in_channels * 2 + num_corruptions, *args, **kwargs)
|
||||
|
||||
def forward(self, x, timesteps, low_res=None, **kwargs):
|
||||
_, _, new_height, new_width = x.shape
|
||||
def forward(self, x, timesteps, low_res=None, corruption_factor=None, **kwargs):
|
||||
b, _, new_height, new_width = x.shape
|
||||
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
|
||||
if corruption_factor is not None:
|
||||
corruption_factor = corruption_factor.view(b, -1, 1, 1).repeat(1, 1, new_height, new_width)
|
||||
else:
|
||||
corruption_factor = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device)
|
||||
upsampled = torch.cat([upsampled, corruption_factor], dim=1)
|
||||
x = th.cat([x, upsampled], dim=1)
|
||||
res = super().forward(x, timesteps, **kwargs)
|
||||
return res
|
||||
|
|
Loading…
Reference in New Issue
Block a user