Add meta-anomaly detection, colorjitter augmentation

James Betker 2021-06-29 13:41:55 -06:00
parent 46e9f62be0
commit 6fd16ea9c8
6 changed files with 86 additions and 20 deletions

View File

@@ -1,8 +1,13 @@
+import functools
 import random
 from math import cos, pi
 import cv2
+import kornia
 import numpy as np
+import torch
+from kornia.augmentation import ColorJitter
 from data.util import read_img
 from PIL import Image
 from io import BytesIO
@@ -65,29 +70,48 @@ class ImageCorruptor:
         # Sources of entropy
         corrupted_imgs = []
         entropy = []
+        undo_fns = []
         applied_augs = augmentations + self.fixed_corruptions
         for img in imgs:
             for aug in augmentations:
                 r = self.get_rand()
-                img = self.apply_corruption(img, aug, r, applied_augs)
+                img, undo_fn = self.apply_corruption(img, aug, r, applied_augs)
+                if undo_fn is not None:
+                    undo_fns.append(undo_fn)
             for aug in self.fixed_corruptions:
                 r = self.get_rand()
-                img = self.apply_corruption(img, aug, r, applied_augs)
+                img, undo_fn = self.apply_corruption(img, aug, r, applied_augs)
                 entropy.append(r)
+                if undo_fn is not None:
+                    undo_fns.append(undo_fn)
+            # Apply undo_fns after all corruptions are finished, in same order.
+            for ufn in undo_fns:
+                img = ufn(img)
             corrupted_imgs.append(img)
 
         if return_entropy:
             return corrupted_imgs, entropy
         else:
             return corrupted_imgs
 
     def apply_corruption(self, img, aug, rand_val, applied_augmentations):
+        undo_fn = None
         if 'color_quantization' in aug:
             # Color quantization
             quant_div = 2 ** (int(rand_val * 10 / 3) + 2)
             img = img * 255
             img = (img // quant_div) * quant_div
             img = img / 255
+        elif 'color_jitter' in aug:
+            lo_end = 0
+            hi_end = .2
+            setting = rand_val * (hi_end - lo_end) + lo_end
+            if setting * 255 > 1:
+                # I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format.
+                img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
+                img = ColorJitter(setting, setting, setting, setting)(img)
+                img = img.squeeze(0).permute(1,2,0).numpy()
         elif 'gaussian_blur' in aug:
             img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
         elif 'motion_blur' in aug:
@@ -105,14 +129,23 @@ class ImageCorruptor:
             pass
         elif 'lq_resampling' in aug:
             # Random mode interpolation HR->LR->HR
-            scale = 2
             if 'lq_resampling4x' == aug:
                 scale = 4
-            interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
-            mode = random.randint(0,4) % len(interpolation_modes)
-            # Downsample first, then upsample using the random mode.
-            img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
-            img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode)
+            else:
+                if rand_val < .3:
+                    scale = 1
+                elif rand_val < .7:
+                    scale = 2
+                else:
+                    scale = 4
+            if scale > 1:
+                interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
+                mode = random.randint(0,4) % len(interpolation_modes)
+                # Downsample first, then upsample using the random mode.
+                img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=mode)
+                def lq_resampling_undo_fn(scale, img):
+                    return cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=cv2.INTER_LINEAR)
+                undo_fn = functools.partial(lq_resampling_undo_fn, scale)
         elif 'color_shift' in aug:
             # Color shift
             pass
@@ -127,8 +160,8 @@ class ImageCorruptor:
             if 'noise-5' == aug:
                 noise_intensity = 5 / 255.0
             else:
-                noise_intensity = (rand_val*4 + 2) / 255.0
-            img += np.random.randn(*img.shape) * noise_intensity
+                noise_intensity = (rand_val*6) / 255.0
+            img += np.random.rand(*img.shape) * noise_intensity
         elif 'jpeg' in aug:
             if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations:
                 if aug == 'jpeg':
@@ -162,7 +195,9 @@ class ImageCorruptor:
             # Lightening / saturation
             saturation = rand_val * .3
             img = np.clip(img + saturation, a_max=1, a_min=0)
+        elif 'greyscale' in aug:
+            img = np.tile(np.mean(img, axis=2, keepdims=True), [1,1,3])
         elif 'none' not in aug:
             raise NotImplementedError("Augmentation doesn't exist")
 
-        return img
+        return img, undo_fn

View File

@@ -266,14 +266,14 @@ if __name__ == '__main__':
         'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked'],
         'weights': [1],
         'target_size': 256,
-        'scale': 2,
+        'scale': 1,
         'corrupt_before_downsize': True,
         'fetch_alt_image': False,
         'fetch_alt_tiled_image': True,
         'disable_flip': True,
-        'fixed_corruptions': [ 'jpeg-medium' ],
+        'fixed_corruptions': ['lq_resampling', 'jpeg-medium', 'gaussian_blur', 'noise', 'color_jitter'],
         'num_corrupts_per_image': 0,
-        'corruption_blur_scale': 0
+        'corruption_blur_scale': 1
     }
 
     ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0, batch_size=64)

View File

@@ -1,3 +1,4 @@
+import functools
 from abc import abstractmethod
 import math
@@ -702,7 +703,7 @@ class ResNetEncoder(nn.Module):
     ) -> None:
         super(ResNetEncoder, self).__init__()
         if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
+            norm_layer = functools.partial(nn.GroupNorm, 8)
         self._norm_layer = norm_layer
 
         self.inplanes = 64
@@ -812,12 +813,10 @@ class UnetWithBuiltInLatentEncoder(nn.Module):
         }
         super().__init__()
         self.encoder = ResNetEncoder(depth=depth_map[kwargs['image_size']])
-        self.lq_jitter = ColorJitter(.05, .05, .05, .05)
         self.unet = SuperResModel(**kwargs)
 
     def forward(self, x, timesteps, alt_hq, low_res=None, **kwargs):
         latent = self.encoder(alt_hq)
-        low_res = self.lq_jitter((low_res+1)/2)*2-1
         return self.unet(x, timesteps, latent, low_res, **kwargs)
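The norm swap works because nn.BatchNorm2d was the previous default, so norm_layer is evidently invoked with just a channel count; functools.partial pins GroupNorm's first argument (num_groups=8), giving the partial the same single-argument call signature. A quick sanity check of that assumption (shapes are arbitrary):

import functools
import torch
import torch.nn as nn

norm_layer = functools.partial(nn.GroupNorm, 8)   # num_groups fixed at 8
norm = norm_layer(64)                             # equivalent to nn.GroupNorm(8, 64)
print(norm(torch.randn(2, 64, 32, 32)).shape)     # torch.Size([2, 64, 32, 32])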

View File

@@ -97,7 +97,8 @@ class Trainer:
         torch.backends.cudnn.benchmark = True
         # torch.backends.cudnn.deterministic = True
-        # torch.autograd.set_detect_anomaly(True)
+        if opt_get(opt, ['anomaly_detection'], False):
+            torch.autograd.set_detect_anomaly(True)
 
         # Save the compiled opt dict to the global loaded_options variable.
         util.loaded_options = opt

View File

@@ -254,6 +254,13 @@ class ExtensibleTrainer(BaseModel):
             # And finally perform optimization.
             [e.before_optimize(state) for e in self.experiments]
             s.do_step(step)
+            if s.nan_counter > 10:
+                print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
+                self.save(step)
+                self.save_training_state(0, step)
+                raise ArithmeticError
 
             # Call into custom step hooks as well as update EMA params.
             for name, net in self.networks.items():
                 if hasattr(net, "custom_optimizer_step"):

View File

@@ -28,6 +28,13 @@ class ConfigurableStep(Module):
         self.grads_generated = False
         self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
+        # This is a half-measure that can be used between anomaly_detection and running a potentially problematic
+        # trainer bare. With this turned on, the optimizer will not step() if a nan grad is detected. If a model trips
+        # this warning 10 times in a row, the training session is aborted and the model state is saved. This has a
+        # noticeable effect on training speed, but nowhere near as bad as anomaly_detection.
+        self.check_grads_for_nan = opt_get(opt_step, ['check_grads_for_nan'], False)
+        self.nan_counter = 0
 
         self.injectors = []
         if 'injectors' in self.step_opt.keys():
             injector_names = []
@@ -244,8 +251,25 @@ class ConfigurableStep(Module):
             before = opt._config['before'] if 'before' in opt._config.keys() else -1
             if before != -1 and self.env['step'] > before:
                 continue
-            self.scaler.step(opt)
-            self.scaler.update()
+
+            nan_found = False
+            if self.check_grads_for_nan:
+                for pg in opt.param_groups:
+                    for p in pg['params']:
+                        if not torch.isfinite(p.grad).any():
+                            nan_found = True
+                            break
+                    if nan_found:
+                        break
+                if nan_found:
+                    print("NaN found in grads. Throwing this step out.")
+                    self.nan_counter += 1
+                else:
+                    self.nan_counter = 0
+
+            if not nan_found:
+                self.scaler.step(opt)
+                self.scaler.update()
 
     def get_metrics(self):
         return self.loss_accumulator.as_dict()
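Taken together, the last two files are the "meta-anomaly detection" from the commit message: the step skips the optimizer when a non-finite gradient is detected and counts consecutive occurrences, and the trainer saves the model and aborts once that counter passes 10. A standalone sketch of the gradient guard (illustrative only, not the repo's ConfigurableStep; this version uses torch.isfinite(...).all() so a single NaN/Inf element counts as a hit, and it skips parameters that have no gradient):

# Skip the optimizer step when any gradient contains NaN/Inf values and track
# how many consecutive steps have been thrown out.
import torch

def step_if_grads_finite(optimizer, scaler, nan_counter):
    finite = all(
        torch.isfinite(p.grad).all()
        for pg in optimizer.param_groups
        for p in pg['params']
        if p.grad is not None
    )
    if finite:
        scaler.step(optimizer)
        scaler.update()
        return 0                    # reset the consecutive-NaN counter
    print("NaN found in grads. Throwing this step out.")
    return nan_counter + 1          # caller aborts once this exceeds 10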