forked from mrq/DL-Art-School
Allow hq color jittering and corruptions that are not included in the corruption factor
This commit is contained in:
parent
6fd16ea9c8
commit
afa41f1804
|
@ -25,6 +25,16 @@ if __name__ == '__main__':
|
|||
plt.show()
|
||||
'''
|
||||
|
||||
|
||||
def kornia_color_jitter_numpy(img, setting):
|
||||
if setting * 255 > 1:
|
||||
# I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format.
|
||||
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
|
||||
img = ColorJitter(setting, setting, setting, setting)(img)
|
||||
img = img.squeeze(0).permute(1,2,0).numpy()
|
||||
return img
|
||||
|
||||
|
||||
# Performs image corruption on a list of images from a configurable set of corruption
|
||||
# options.
|
||||
class ImageCorruptor:
|
||||
|
@ -107,11 +117,7 @@ class ImageCorruptor:
|
|||
lo_end = 0
|
||||
hi_end = .2
|
||||
setting = rand_val * (hi_end - lo_end) + lo_end
|
||||
if setting * 255 > 1:
|
||||
# I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format.
|
||||
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
|
||||
img = ColorJitter(setting, setting, setting, setting)(img)
|
||||
img = img.squeeze(0).permute(1,2,0).numpy()
|
||||
img = kornia_color_jitter_numpy(img, setting)
|
||||
elif 'gaussian_blur' in aug:
|
||||
img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
|
||||
elif 'motion_blur' in aug:
|
||||
|
|
|
@ -17,7 +17,7 @@ from tqdm import tqdm
|
|||
|
||||
from data import util
|
||||
# Builds a dataset created from a simple folder containing a list of training/test/validation images.
|
||||
from data.image_corruptor import ImageCorruptor
|
||||
from data.image_corruptor import ImageCorruptor, kornia_color_jitter_numpy
|
||||
from data.image_label_parser import VsNetImageLabeler
|
||||
from utils.util import opt_get
|
||||
|
||||
|
@ -50,6 +50,7 @@ class ImageFolderDataset:
|
|||
self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False)
|
||||
self.force_square = opt_get(opt, ['force_square'], True)
|
||||
self.fixed_parameters = {k: torch.tensor(v) for k, v in opt_get(opt, ['fixed_parameters'], {}).items()}
|
||||
self.all_image_color_jitter = opt_get(opt, ['all_image_color_jitter'], 0)
|
||||
if 'normalize' in opt.keys():
|
||||
if opt['normalize'] == 'stylegan2_norm':
|
||||
self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
|
@ -155,6 +156,10 @@ class ImageFolderDataset:
|
|||
dim = min(h, w)
|
||||
hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
|
||||
|
||||
# Perform color jittering on the HQ image if specified. The given value should be between [0,1].
|
||||
if self.all_image_color_jitter > 0:
|
||||
hq = kornia_color_jitter_numpy(hq, self.all_image_color_jitter)
|
||||
|
||||
if self.labeler:
|
||||
assert hq.shape[0] == hq.shape[1] # This just has not been accomodated yet.
|
||||
dim = hq.shape[0]
|
||||
|
@ -273,7 +278,8 @@ if __name__ == '__main__':
|
|||
'disable_flip': True,
|
||||
'fixed_corruptions': ['lq_resampling', 'jpeg-medium', 'gaussian_blur', 'noise', 'color_jitter'],
|
||||
'num_corrupts_per_image': 0,
|
||||
'corruption_blur_scale': 1
|
||||
'corruption_blur_scale': 1,
|
||||
'all_image_color_jitter': .1,
|
||||
}
|
||||
|
||||
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0, batch_size=64)
|
||||
|
|
|
@ -661,13 +661,18 @@ class SuperResModel(UNetModel):
|
|||
"""
|
||||
|
||||
def __init__(self, image_size, in_channels, num_corruptions=0, *args, **kwargs):
|
||||
self.num_corruptions = 0
|
||||
self.num_corruptions = num_corruptions
|
||||
super().__init__(image_size, in_channels * 2 + num_corruptions, *args, **kwargs)
|
||||
|
||||
def forward(self, x, timesteps, low_res=None, corruption_factor=None, **kwargs):
|
||||
b, _, new_height, new_width = x.shape
|
||||
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
|
||||
if corruption_factor is not None:
|
||||
if corruption_factor.shape[1] != self.num_corruptions:
|
||||
if not hasattr(self, '_corruption_factor_warning'):
|
||||
print(f"Warning! Dataloader gave us {corruption_factor.shape[1]} dim but we are only processing {self.num_corruptions}. The last n corruptions will be truncated.")
|
||||
self._corruption_factor_warning = True
|
||||
corruption_factor = corruption_factor[:, :self.num_corruptions]
|
||||
corruption_factor = corruption_factor.view(b, -1, 1, 1).repeat(1, 1, new_height, new_width)
|
||||
else:
|
||||
corruption_factor = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device)
|
||||
|
|
Loading…
Reference in New Issue
Block a user