forked from mrq/DL-Art-School
More features for multi-frame-dataset
This commit is contained in:
parent
aeaf185314
commit
57814f18cf
|
@ -40,6 +40,8 @@ def create_dataset(dataset_opt):
|
|||
from data.full_image_dataset import FullImageDataset as D
|
||||
elif mode == 'single_image_extensible':
|
||||
from data.single_image_dataset import SingleImageDataset as D
|
||||
elif mode == 'multi_frame_extensible':
|
||||
from data.multi_frame_dataset import MultiFrameDataset as D
|
||||
elif mode == 'combined':
|
||||
from data.combined_dataset import CombinedDataset as D
|
||||
else:
|
||||
|
|
|
@ -9,18 +9,21 @@ from io import BytesIO
|
|||
# options.
|
||||
class ImageCorruptor:
|
||||
def __init__(self, opt):
|
||||
self.fixed_corruptions = opt['fixed_corruptions']
|
||||
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 2
|
||||
if self.num_corrupts == 0:
|
||||
return
|
||||
self.fixed_corruptions = opt['fixed_corruptions']
|
||||
self.random_corruptions = opt['random_corruptions']
|
||||
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
|
||||
|
||||
def corrupt_images(self, imgs):
|
||||
if self.num_corrupts == 0:
|
||||
if self.num_corrupts == 0 and not self.fixed_corruptions:
|
||||
return imgs
|
||||
|
||||
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
|
||||
if self.num_corrupts == 0:
|
||||
augmentations = []
|
||||
else:
|
||||
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
|
||||
# Source of entropy, which should be used across all images.
|
||||
rand_int_f = random.randint(1, 999999)
|
||||
rand_int_a = random.randint(1, 999999)
|
||||
|
@ -80,8 +83,14 @@ class ImageCorruptor:
|
|||
noise_intensity = (rand_int % 4 + 2) / 255.0 # Between 1-4
|
||||
img += np.random.randn(*img.shape) * noise_intensity
|
||||
elif 'jpeg' in aug:
|
||||
if aug == 'jpeg':
|
||||
lo=10
|
||||
range=20
|
||||
elif aug == 'jpeg-medium':
|
||||
lo=23
|
||||
range=25
|
||||
# JPEG compression
|
||||
qf = (rand_int % 20 + 5) # Between 5-25
|
||||
qf = (rand_int % range + lo)
|
||||
# cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead.
|
||||
img = (img * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img)
|
||||
|
|
|
@ -68,9 +68,9 @@ if __name__ == '__main__':
|
|||
'force_multiple': 32,
|
||||
'scale': 2,
|
||||
'eval': False,
|
||||
'fixed_corruptions': [],
|
||||
'random_corruptions': ['color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise', 'saturation'],
|
||||
'num_corrupts_per_image': 1,
|
||||
'fixed_corruptions': ['jpeg-medium'],
|
||||
'random_corruptions': [],
|
||||
'num_corrupts_per_image': 0,
|
||||
'num_frames': 10
|
||||
}
|
||||
|
||||
|
@ -80,7 +80,7 @@ if __name__ == '__main__':
|
|||
for i in range(100000, len(ds)):
|
||||
import random
|
||||
o = ds[random.randint(0, 1000000)]
|
||||
k = 'gt_fullsize_ref'
|
||||
k = 'GT'
|
||||
v = o[k]
|
||||
if 'path' not in k and 'center' not in k:
|
||||
fr, f, h, w = v.shape
|
||||
|
|
Loading…
Reference in New Issue
Block a user