Add meta-anomaly detection, colorjitter augmentation

James Betker 2021-06-29 13:41:55 -06:00
parent 46e9f62be0
commit 6fd16ea9c8
6 changed files with 86 additions and 20 deletions

View File

@ -1,8 +1,13 @@
import functools
import random
from math import cos, pi
import cv2
import kornia
import numpy as np
import torch
from kornia.augmentation import ColorJitter
from data.util import read_img
from PIL import Image
from io import BytesIO
@ -65,29 +70,48 @@ class ImageCorruptor:
# Sources of entropy
corrupted_imgs = []
entropy = []
undo_fns = []
applied_augs = augmentations + self.fixed_corruptions
for img in imgs:
for aug in augmentations:
r = self.get_rand()
img = self.apply_corruption(img, aug, r, applied_augs)
img, undo_fn = self.apply_corruption(img, aug, r, applied_augs)
if undo_fn is not None:
undo_fns.append(undo_fn)
for aug in self.fixed_corruptions:
r = self.get_rand()
img = self.apply_corruption(img, aug, r, applied_augs)
img, undo_fn = self.apply_corruption(img, aug, r, applied_augs)
entropy.append(r)
if undo_fn is not None:
undo_fns.append(undo_fn)
# Apply undo_fns after all corruptions are finished, in same order.
for ufn in undo_fns:
img = ufn(img)
corrupted_imgs.append(img)
if return_entropy:
return corrupted_imgs, entropy
else:
return corrupted_imgs
def apply_corruption(self, img, aug, rand_val, applied_augmentations):
undo_fn = None
if 'color_quantization' in aug:
# Color quantization
quant_div = 2 ** (int(rand_val * 10 / 3) + 2)
img = img * 255
img = (img // quant_div) * quant_div
img = img / 255
elif 'color_jitter' in aug:
lo_end = 0
hi_end = .2
setting = rand_val * (hi_end - lo_end) + lo_end
if setting * 255 > 1:
# I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format.
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
img = ColorJitter(setting, setting, setting, setting)(img)
img = img.squeeze(0).permute(1,2,0).numpy()
elif 'gaussian_blur' in aug:
img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
elif 'motion_blur' in aug:
@ -105,14 +129,23 @@ class ImageCorruptor:
pass
elif 'lq_resampling' in aug:
# Random mode interpolation HR->LR->HR
scale = 2
if 'lq_resampling4x' == aug:
scale = 4
interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
mode = random.randint(0,4) % len(interpolation_modes)
# Downsample first, then upsample using the random mode.
img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode)
else:
if rand_val < .3:
scale = 1
elif rand_val < .7:
scale = 2
else:
scale = 4
if scale > 1:
interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
mode = random.randint(0,4) % len(interpolation_modes)
# Downsample first, then upsample using the random mode.
img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=mode)
def lq_resampling_undo_fn(scale, img):
return cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=cv2.INTER_LINEAR)
undo_fn = functools.partial(lq_resampling_undo_fn, scale)
elif 'color_shift' in aug:
# Color shift
pass
@ -127,8 +160,8 @@ class ImageCorruptor:
if 'noise-5' == aug:
noise_intensity = 5 / 255.0
else:
noise_intensity = (rand_val*4 + 2) / 255.0
img += np.random.randn(*img.shape) * noise_intensity
noise_intensity = (rand_val*6) / 255.0
img += np.random.rand(*img.shape) * noise_intensity
elif 'jpeg' in aug:
if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations:
if aug == 'jpeg':
@ -162,7 +195,9 @@ class ImageCorruptor:
# Lightening / saturation
saturation = rand_val * .3
img = np.clip(img + saturation, a_max=1, a_min=0)
elif 'greyscale' in aug:
img = np.tile(np.mean(img, axis=2, keepdims=True), [1,1,3])
elif 'none' not in aug:
raise NotImplementedError("Augmentation doesn't exist")
return img
return img, undo_fn
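
The (img, undo_fn) pair returned above is the heart of the new undo mechanism: a corruption may hand back a callable that is applied only after every other corruption has run, in registration order. A minimal standalone sketch of that deferral, using a hypothetical downsample_with_undo helper rather than the class above:

import functools
import cv2
import numpy as np

def _upsample(scale, im):
    # Inverse step: restore the original resolution with bilinear resampling.
    return cv2.resize(im, dsize=(im.shape[1] * scale, im.shape[0] * scale),
                      interpolation=cv2.INTER_LINEAR)

def downsample_with_undo(img, scale):
    # Corrupt now; return a callable that undoes the resolution change later.
    small = cv2.resize(img, dsize=(img.shape[1] // scale, img.shape[0] // scale),
                       interpolation=cv2.INTER_NEAREST)
    return small, functools.partial(_upsample, scale)

img = np.random.rand(64, 64, 3).astype(np.float32)
img, undo_fn = downsample_with_undo(img, 2)
# ...blur / noise / jpeg corruptions would run here at the reduced resolution...
img = undo_fn(img)  # deferred until all corruptions are finished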

View File

@ -266,14 +266,14 @@ if __name__ == '__main__':
'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked'],
'weights': [1],
'target_size': 256,
'scale': 2,
'scale': 1,
'corrupt_before_downsize': True,
'fetch_alt_image': False,
'fetch_alt_tiled_image': True,
'disable_flip': True,
'fixed_corruptions': [ 'jpeg-medium' ],
'fixed_corruptions': ['lq_resampling', 'jpeg-medium', 'gaussian_blur', 'noise', 'color_jitter'],
'num_corrupts_per_image': 0,
'corruption_blur_scale': 0
'corruption_blur_scale': 1
}
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0, batch_size=64)
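
A quick way to sanity-check the heavier corruption stack configured above is to pull a few batches from the loader and print tensor shapes; this sketch assumes only that batches are dicts of tensors and that torch is imported in this script:

import torch

for i, batch in enumerate(ds):
    # Print the shape of every tensor entry in the batch dict.
    print(i, {k: tuple(v.shape) for k, v in batch.items() if torch.is_tensor(v)})
    if i >= 2:
        break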

View File

@ -1,3 +1,4 @@
import functools
from abc import abstractmethod
import math
@ -702,7 +703,7 @@ class ResNetEncoder(nn.Module):
) -> None:
super(ResNetEncoder, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
norm_layer = functools.partial(nn.GroupNorm, 8)
self._norm_layer = norm_layer
self.inplanes = 64
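
Swapping nn.BatchNorm2d for functools.partial(nn.GroupNorm, 8) works because nn.GroupNorm takes (num_groups, num_channels); binding num_groups leaves a factory with the same one-argument call shape the ResNet code expects. A small illustration, separate from the encoder:

import functools
import torch
from torch import nn

norm_layer = functools.partial(nn.GroupNorm, 8)   # num_groups fixed at 8
norm = norm_layer(64)                             # equivalent to nn.GroupNorm(8, 64)
print(norm(torch.randn(2, 64, 16, 16)).shape)     # torch.Size([2, 64, 16, 16])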
@ -812,12 +813,10 @@ class UnetWithBuiltInLatentEncoder(nn.Module):
}
super().__init__()
self.encoder = ResNetEncoder(depth=depth_map[kwargs['image_size']])
self.lq_jitter = ColorJitter(.05, .05, .05, .05)
self.unet = SuperResModel(**kwargs)
def forward(self, x, timesteps, alt_hq, low_res=None, **kwargs):
latent = self.encoder(alt_hq)
low_res = self.lq_jitter((low_res+1)/2)*2-1
return self.unet(x, timesteps, latent, low_res, **kwargs)
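
The (low_res + 1) / 2 ... * 2 - 1 round trip exists because the diffusion inputs are scaled to [-1, 1] while Kornia's ColorJitter expects images in [0, 1]. A minimal sketch of that rescaling, pulled out of the module:

import torch
from kornia.augmentation import ColorJitter

jitter = ColorJitter(.05, .05, .05, .05)

def jitter_pm1(low_res: torch.Tensor) -> torch.Tensor:
    # Map [-1, 1] -> [0, 1], jitter brightness/contrast/saturation/hue, map back.
    return jitter((low_res + 1) / 2) * 2 - 1

out = jitter_pm1(torch.rand(1, 3, 64, 64) * 2 - 1)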

View File

@ -97,7 +97,8 @@ class Trainer:
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
# torch.autograd.set_detect_anomaly(True)
if opt_get(opt, ['anomaly_detection'], False):
torch.autograd.set_detect_anomaly(True)
# Save the compiled opt dict to the global loaded_options variable.
util.loaded_options = opt
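
For context on why this is opt-gated rather than always on: set_detect_anomaly re-checks the output of every backward op and raises as soon as a NaN gradient appears, which is very slow but pinpoints the offending function. An illustrative standalone snippet, not part of the trainer:

import torch

torch.autograd.set_detect_anomaly(True)

x = torch.tensor([1.0], requires_grad=True)
y = x * float('nan')          # forward already produces a NaN
try:
    y.backward()
except RuntimeError as err:   # anomaly mode raises and names the backward op
    print("caught:", err)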

View File

@ -254,6 +254,13 @@ class ExtensibleTrainer(BaseModel):
# And finally perform optimization.
[e.before_optimize(state) for e in self.experiments]
s.do_step(step)
if s.nan_counter > 10:
print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
self.save(step)
self.save_training_state(0, step)
raise ArithmeticError
# Call into custom step hooks as well as update EMA params.
for name, net in self.networks.items():
if hasattr(net, "custom_optimizer_step"):

View File

@ -28,6 +28,13 @@ class ConfigurableStep(Module):
self.grads_generated = False
self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
# This is a half-measure that can be used between anomaly_detection and running a potentially problematic
# trainer bare. With this turned on, the optimizer will not step() if a nan grad is detected. If a model trips
# this warning 10 times in a row, the training session is aborted and the model state is saved. This has a
# noticeable effect on training speed, but nowhere near as bad as anomaly_detection.
self.check_grads_for_nan = opt_get(opt_step, ['check_grads_for_nan'], False)
self.nan_counter = 0
self.injectors = []
if 'injectors' in self.step_opt.keys():
injector_names = []
@ -244,8 +251,25 @@ class ConfigurableStep(Module):
before = opt._config['before'] if 'before' in opt._config.keys() else -1
if before != -1 and self.env['step'] > before:
continue
self.scaler.step(opt)
self.scaler.update()
nan_found = False
if self.check_grads_for_nan:
for pg in opt.param_groups:
for p in pg['params']:
if not torch.isfinite(p.grad).any():
nan_found = True
break
if nan_found:
break
if nan_found:
print("NaN found in grads. Throwing this step out.")
self.nan_counter += 1
else:
self.nan_counter = 0
if not nan_found:
self.scaler.step(opt)
self.scaler.update()
def get_metrics(self):
return self.loss_accumulator.as_dict()
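
This guard is the "meta-anomaly detection" from the commit title: before letting the GradScaler step, every parameter gradient is scanned, the step is thrown away if anything non-finite turns up, and the trainer aborts after ten bad steps in a row. A standalone sketch of the same check (hypothetical helper name; it also skips parameters that have no gradient):

import torch

def grads_are_finite(optimizer: torch.optim.Optimizer) -> bool:
    # True only if every existing gradient is free of NaN/Inf values.
    for group in optimizer.param_groups:
        for p in group['params']:
            if p.grad is not None and not torch.isfinite(p.grad).all():
                return False
    return True

# Mirroring the step above (scaler is a torch.cuda.amp.GradScaler):
# if grads_are_finite(opt):
#     nan_counter = 0
#     scaler.step(opt)
#     scaler.update()
# else:
#     nan_counter += 1   # the trainer aborts once this exceeds 10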