forked from mrq/DL-Art-School
More features for multi-frame-dataset
This commit is contained in:
parent
aeaf185314
commit
57814f18cf
|
@ -40,6 +40,8 @@ def create_dataset(dataset_opt):
|
||||||
from data.full_image_dataset import FullImageDataset as D
|
from data.full_image_dataset import FullImageDataset as D
|
||||||
elif mode == 'single_image_extensible':
|
elif mode == 'single_image_extensible':
|
||||||
from data.single_image_dataset import SingleImageDataset as D
|
from data.single_image_dataset import SingleImageDataset as D
|
||||||
|
elif mode == 'multi_frame_extensible':
|
||||||
|
from data.multi_frame_dataset import MultiFrameDataset as D
|
||||||
elif mode == 'combined':
|
elif mode == 'combined':
|
||||||
from data.combined_dataset import CombinedDataset as D
|
from data.combined_dataset import CombinedDataset as D
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -9,17 +9,20 @@ from io import BytesIO
|
||||||
# options.
|
# options.
|
||||||
class ImageCorruptor:
|
class ImageCorruptor:
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
|
self.fixed_corruptions = opt['fixed_corruptions']
|
||||||
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 2
|
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 2
|
||||||
if self.num_corrupts == 0:
|
if self.num_corrupts == 0:
|
||||||
return
|
return
|
||||||
self.fixed_corruptions = opt['fixed_corruptions']
|
|
||||||
self.random_corruptions = opt['random_corruptions']
|
self.random_corruptions = opt['random_corruptions']
|
||||||
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
|
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
|
||||||
|
|
||||||
def corrupt_images(self, imgs):
|
def corrupt_images(self, imgs):
|
||||||
if self.num_corrupts == 0:
|
if self.num_corrupts == 0 and not self.fixed_corruptions:
|
||||||
return imgs
|
return imgs
|
||||||
|
|
||||||
|
if self.num_corrupts == 0:
|
||||||
|
augmentations = []
|
||||||
|
else:
|
||||||
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
|
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
|
||||||
# Source of entropy, which should be used across all images.
|
# Source of entropy, which should be used across all images.
|
||||||
rand_int_f = random.randint(1, 999999)
|
rand_int_f = random.randint(1, 999999)
|
||||||
|
@ -80,8 +83,14 @@ class ImageCorruptor:
|
||||||
noise_intensity = (rand_int % 4 + 2) / 255.0 # Between 1-4
|
noise_intensity = (rand_int % 4 + 2) / 255.0 # Between 1-4
|
||||||
img += np.random.randn(*img.shape) * noise_intensity
|
img += np.random.randn(*img.shape) * noise_intensity
|
||||||
elif 'jpeg' in aug:
|
elif 'jpeg' in aug:
|
||||||
|
if aug == 'jpeg':
|
||||||
|
lo=10
|
||||||
|
range=20
|
||||||
|
elif aug == 'jpeg-medium':
|
||||||
|
lo=23
|
||||||
|
range=25
|
||||||
# JPEG compression
|
# JPEG compression
|
||||||
qf = (rand_int % 20 + 5) # Between 5-25
|
qf = (rand_int % range + lo)
|
||||||
# cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead.
|
# cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead.
|
||||||
img = (img * 255).astype(np.uint8)
|
img = (img * 255).astype(np.uint8)
|
||||||
img = Image.fromarray(img)
|
img = Image.fromarray(img)
|
||||||
|
|
|
@ -68,9 +68,9 @@ if __name__ == '__main__':
|
||||||
'force_multiple': 32,
|
'force_multiple': 32,
|
||||||
'scale': 2,
|
'scale': 2,
|
||||||
'eval': False,
|
'eval': False,
|
||||||
'fixed_corruptions': [],
|
'fixed_corruptions': ['jpeg-medium'],
|
||||||
'random_corruptions': ['color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise', 'saturation'],
|
'random_corruptions': [],
|
||||||
'num_corrupts_per_image': 1,
|
'num_corrupts_per_image': 0,
|
||||||
'num_frames': 10
|
'num_frames': 10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ if __name__ == '__main__':
|
||||||
for i in range(100000, len(ds)):
|
for i in range(100000, len(ds)):
|
||||||
import random
|
import random
|
||||||
o = ds[random.randint(0, 1000000)]
|
o = ds[random.randint(0, 1000000)]
|
||||||
k = 'gt_fullsize_ref'
|
k = 'GT'
|
||||||
v = o[k]
|
v = o[k]
|
||||||
if 'path' not in k and 'center' not in k:
|
if 'path' not in k and 'center' not in k:
|
||||||
fr, f, h, w = v.shape
|
fr, f, h, w = v.shape
|
||||||
|
|
Loading…
Reference in New Issue
Block a user