Allow hq color jittering and corruptions that are not included in the corruption factor

This commit is contained in:
James Betker 2021-06-30 09:44:46 -06:00
parent 6fd16ea9c8
commit afa41f1804
3 changed files with 25 additions and 8 deletions

View File

@ -25,6 +25,16 @@ if __name__ == '__main__':
plt.show()
'''
def kornia_color_jitter_numpy(img, setting):
if setting * 255 > 1:
# I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format.
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
img = ColorJitter(setting, setting, setting, setting)(img)
img = img.squeeze(0).permute(1,2,0).numpy()
return img
# Performs image corruption on a list of images from a configurable set of corruption
# options.
class ImageCorruptor:
@ -107,11 +117,7 @@ class ImageCorruptor:
lo_end = 0
hi_end = .2
setting = rand_val * (hi_end - lo_end) + lo_end
if setting * 255 > 1:
# I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format.
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
img = ColorJitter(setting, setting, setting, setting)(img)
img = img.squeeze(0).permute(1,2,0).numpy()
img = kornia_color_jitter_numpy(img, setting)
elif 'gaussian_blur' in aug:
img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
elif 'motion_blur' in aug:

View File

@ -17,7 +17,7 @@ from tqdm import tqdm
from data import util
# Builds a dataset created from a simple folder containing a list of training/test/validation images.
from data.image_corruptor import ImageCorruptor
from data.image_corruptor import ImageCorruptor, kornia_color_jitter_numpy
from data.image_label_parser import VsNetImageLabeler
from utils.util import opt_get
@ -50,6 +50,7 @@ class ImageFolderDataset:
self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False)
self.force_square = opt_get(opt, ['force_square'], True)
self.fixed_parameters = {k: torch.tensor(v) for k, v in opt_get(opt, ['fixed_parameters'], {}).items()}
self.all_image_color_jitter = opt_get(opt, ['all_image_color_jitter'], 0)
if 'normalize' in opt.keys():
if opt['normalize'] == 'stylegan2_norm':
self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
@ -155,6 +156,10 @@ class ImageFolderDataset:
dim = min(h, w)
hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
# Perform color jittering on the HQ image if specified. The given value should be between [0,1].
if self.all_image_color_jitter > 0:
hq = kornia_color_jitter_numpy(hq, self.all_image_color_jitter)
if self.labeler:
assert hq.shape[0] == hq.shape[1] # This just has not been accomodated yet.
dim = hq.shape[0]
@ -273,7 +278,8 @@ if __name__ == '__main__':
'disable_flip': True,
'fixed_corruptions': ['lq_resampling', 'jpeg-medium', 'gaussian_blur', 'noise', 'color_jitter'],
'num_corrupts_per_image': 0,
'corruption_blur_scale': 1
'corruption_blur_scale': 1,
'all_image_color_jitter': .1,
}
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0, batch_size=64)

View File

@ -661,13 +661,18 @@ class SuperResModel(UNetModel):
"""
def __init__(self, image_size, in_channels, num_corruptions=0, *args, **kwargs):
self.num_corruptions = 0
self.num_corruptions = num_corruptions
super().__init__(image_size, in_channels * 2 + num_corruptions, *args, **kwargs)
def forward(self, x, timesteps, low_res=None, corruption_factor=None, **kwargs):
b, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
if corruption_factor is not None:
if corruption_factor.shape[1] != self.num_corruptions:
if not hasattr(self, '_corruption_factor_warning'):
print(f"Warning! Dataloader gave us {corruption_factor.shape[1]} dim but we are only processing {self.num_corruptions}. The last n corruptions will be truncated.")
self._corruption_factor_warning = True
corruption_factor = corruption_factor[:, :self.num_corruptions]
corruption_factor = corruption_factor.view(b, -1, 1, 1).repeat(1, 1, new_height, new_width)
else:
corruption_factor = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device)