More features for multi-frame-dataset

This commit is contained in:
James Betker 2020-09-28 14:26:15 -06:00
parent aeaf185314
commit 57814f18cf
3 changed files with 19 additions and 8 deletions

View File

@ -40,6 +40,8 @@ def create_dataset(dataset_opt):
from data.full_image_dataset import FullImageDataset as D
elif mode == 'single_image_extensible':
from data.single_image_dataset import SingleImageDataset as D
elif mode == 'multi_frame_extensible':
from data.multi_frame_dataset import MultiFrameDataset as D
elif mode == 'combined':
from data.combined_dataset import CombinedDataset as D
else:

View File

@ -9,18 +9,21 @@ from io import BytesIO
# options.
class ImageCorruptor:
def __init__(self, opt):
    """Read the corruption settings from ``opt``.

    Keys read from ``opt`` (a dict of options):
      - 'fixed_corruptions': required.
      - 'num_corrupts_per_image': optional, defaults to 2.
      - 'corruption_blur_scale': optional, defaults to 1.
      - 'random_corruptions': required when num_corrupts_per_image is
        non-zero; optional (defaults to an empty list) when it is 0.

    Raises:
        KeyError: if a required key is missing from ``opt``.
    """
    self.fixed_corruptions = opt['fixed_corruptions']
    self.num_corrupts = opt.get('num_corrupts_per_image', 2)
    # Every attribute is assigned on every path. corrupt_images() only
    # short-circuits when there are neither random nor fixed corruptions, so
    # with num_corrupts == 0 and fixed corruptions configured the rest of the
    # class still runs and must not hit a half-initialised instance.
    self.blur_scale = opt.get('corruption_blur_scale', 1)
    if self.num_corrupts == 0:
        # No random corruptions are drawn in this mode, so the list may be
        # omitted from the options entirely.
        self.random_corruptions = opt.get('random_corruptions', [])
    else:
        self.random_corruptions = opt['random_corruptions']
def corrupt_images(self, imgs):
if self.num_corrupts == 0:
if self.num_corrupts == 0 and not self.fixed_corruptions:
return imgs
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
if self.num_corrupts == 0:
augmentations = []
else:
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
# Source of entropy, which should be used across all images.
rand_int_f = random.randint(1, 999999)
rand_int_a = random.randint(1, 999999)
@ -80,8 +83,14 @@ class ImageCorruptor:
noise_intensity = (rand_int % 4 + 2) / 255.0 # Between 2-5 (out of 255)
img += np.random.randn(*img.shape) * noise_intensity
elif 'jpeg' in aug:
if aug == 'jpeg':
lo=10
range=20
elif aug == 'jpeg-medium':
lo=23
range=25
# JPEG compression
qf = (rand_int % 20 + 5) # Between 5-25
qf = (rand_int % range + lo)
# cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead.
img = (img * 255).astype(np.uint8)
img = Image.fromarray(img)

View File

@ -68,9 +68,9 @@ if __name__ == '__main__':
'force_multiple': 32,
'scale': 2,
'eval': False,
'fixed_corruptions': [],
'random_corruptions': ['color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise', 'saturation'],
'num_corrupts_per_image': 1,
'fixed_corruptions': ['jpeg-medium'],
'random_corruptions': [],
'num_corrupts_per_image': 0,
'num_frames': 10
}
@ -80,7 +80,7 @@ if __name__ == '__main__':
for i in range(100000, len(ds)):
import random
o = ds[random.randint(0, 1000000)]
k = 'gt_fullsize_ref'
k = 'GT'
v = o[k]
if 'path' not in k and 'center' not in k:
fr, f, h, w = v.shape