forked from mrq/DL-Art-School
Add meta-anomaly detection, colorjitter augmentation
parent 46e9f62be0
commit 6fd16ea9c8
@@ -1,8 +1,13 @@
+import functools
 import random
 from math import cos, pi

 import cv2
+import kornia
 import numpy as np
+import torch
+from kornia.augmentation import ColorJitter
+
 from data.util import read_img
 from PIL import Image
 from io import BytesIO
@@ -65,29 +70,48 @@ class ImageCorruptor:
         # Sources of entropy
         corrupted_imgs = []
         entropy = []
+        undo_fns = []
         applied_augs = augmentations + self.fixed_corruptions
         for img in imgs:
             for aug in augmentations:
                 r = self.get_rand()
-                img = self.apply_corruption(img, aug, r, applied_augs)
+                img, undo_fn = self.apply_corruption(img, aug, r, applied_augs)
+                if undo_fn is not None:
+                    undo_fns.append(undo_fn)
             for aug in self.fixed_corruptions:
                 r = self.get_rand()
-                img = self.apply_corruption(img, aug, r, applied_augs)
+                img, undo_fn = self.apply_corruption(img, aug, r, applied_augs)
                 entropy.append(r)
+                if undo_fn is not None:
+                    undo_fns.append(undo_fn)
+            # Apply undo_fns after all corruptions are finished, in same order.
+            for ufn in undo_fns:
+                img = ufn(img)
             corrupted_imgs.append(img)

         if return_entropy:
             return corrupted_imgs, entropy
         else:
             return corrupted_imgs

     def apply_corruption(self, img, aug, rand_val, applied_augmentations):
+        undo_fn = None
         if 'color_quantization' in aug:
             # Color quantization
             quant_div = 2 ** (int(rand_val * 10 / 3) + 2)
             img = img * 255
             img = (img // quant_div) * quant_div
             img = img / 255
+        elif 'color_jitter' in aug:
+            lo_end = 0
+            hi_end = .2
+            setting = rand_val * (hi_end - lo_end) + lo_end
+            if setting * 255 > 1:
+                # I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format.
+                img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
+                img = ColorJitter(setting, setting, setting, setting)(img)
+                img = img.squeeze(0).permute(1,2,0).numpy()
         elif 'gaussian_blur' in aug:
             img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
         elif 'motion_blur' in aug:
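
The new 'color_jitter' branch is the only corruption that leaves numpy: it round-trips through kornia, which wants a BCHW tensor. A minimal standalone sketch of that round trip, assuming a float32 HWC image in [0, 1] (the format `apply_corruption` receives); sizes and the `setting` value are illustrative only:

```python
import numpy as np
import torch
from kornia.augmentation import ColorJitter

# Float HWC image in [0, 1], the format apply_corruption() receives.
img = np.random.rand(64, 64, 3).astype(np.float32)
setting = 0.1  # jitter strength; the diff derives this from rand_val

t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)  # HWC -> BCHW
t = ColorJitter(setting, setting, setting, setting)(t)   # brightness, contrast, saturation, hue
img = t.squeeze(0).permute(1, 2, 0).numpy()              # BCHW -> HWC
```
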
@@ -105,14 +129,23 @@ class ImageCorruptor:
             pass
         elif 'lq_resampling' in aug:
             # Random mode interpolation HR->LR->HR
-            scale = 2
             if 'lq_resampling4x' == aug:
                 scale = 4
-            interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
-            mode = random.randint(0,4) % len(interpolation_modes)
-            # Downsample first, then upsample using the random mode.
-            img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
-            img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode)
+            else:
+                if rand_val < .3:
+                    scale = 1
+                elif rand_val < .7:
+                    scale = 2
+                else:
+                    scale = 4
+            if scale > 1:
+                interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
+                mode = random.randint(0,4) % len(interpolation_modes)
+                # Downsample first, then upsample using the random mode.
+                img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=mode)
+                def lq_resampling_undo_fn(scale, img):
+                    return cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=cv2.INTER_LINEAR)
+                undo_fn = functools.partial(lq_resampling_undo_fn, scale)
         elif 'color_shift' in aug:
             # Color shift
             pass
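
The deferred undo works because `functools.partial` freezes the randomly chosen `scale` at corruption time, so the matching upsample can run after every other corruption has been applied. A self-contained sketch of the pattern (image sizes are hypothetical):

```python
import functools
import cv2
import numpy as np

def lq_resampling_undo_fn(scale, img):
    # Upsample back to the pre-corruption resolution with a fixed kernel.
    return cv2.resize(img, dsize=(img.shape[1] * scale, img.shape[0] * scale),
                      interpolation=cv2.INTER_LINEAR)

img = np.random.rand(32, 32, 3).astype(np.float32)
scale = 2
lq = cv2.resize(img, dsize=(img.shape[1] // scale, img.shape[0] // scale),
                interpolation=cv2.INTER_LINEAR)
undo_fn = functools.partial(lq_resampling_undo_fn, scale)  # scale is bound now
restored = undo_fn(lq)  # runs later, after the remaining corruptions
assert restored.shape == img.shape
```
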
@@ -127,8 +160,8 @@ class ImageCorruptor:
             if 'noise-5' == aug:
                 noise_intensity = 5 / 255.0
             else:
-                noise_intensity = (rand_val*4 + 2) / 255.0
-            img += np.random.randn(*img.shape) * noise_intensity
+                noise_intensity = (rand_val*6) / 255.0
+            img += np.random.rand(*img.shape) * noise_intensity
         elif 'jpeg' in aug:
             if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations:
                 if aug == 'jpeg':
@@ -162,7 +195,9 @@ class ImageCorruptor:
             # Lightening / saturation
             saturation = rand_val * .3
             img = np.clip(img + saturation, a_max=1, a_min=0)
+        elif 'greyscale' in aug:
+            img = np.tile(np.mean(img, axis=2, keepdims=True), [1,1,3])
         elif 'none' not in aug:
             raise NotImplementedError("Augmentation doesn't exist")

-        return img
+        return img, undo_fn
@@ -266,14 +266,14 @@ if __name__ == '__main__':
         'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked'],
         'weights': [1],
         'target_size': 256,
-        'scale': 2,
+        'scale': 1,
         'corrupt_before_downsize': True,
         'fetch_alt_image': False,
         'fetch_alt_tiled_image': True,
         'disable_flip': True,
-        'fixed_corruptions': [ 'jpeg-medium' ],
+        'fixed_corruptions': ['lq_resampling', 'jpeg-medium', 'gaussian_blur', 'noise', 'color_jitter'],
         'num_corrupts_per_image': 0,
-        'corruption_blur_scale': 0
+        'corruption_blur_scale': 1
     }

     ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0, batch_size=64)
@@ -1,3 +1,4 @@
+import functools
 from abc import abstractmethod

 import math
@@ -702,7 +703,7 @@ class ResNetEncoder(nn.Module):
     ) -> None:
         super(ResNetEncoder, self).__init__()
         if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
+            norm_layer = functools.partial(nn.GroupNorm, 8)
         self._norm_layer = norm_layer

         self.inplanes = 64
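
This swap works because `functools.partial(nn.GroupNorm, 8)` produces a factory with the same one-argument call signature as `nn.BatchNorm2d`, so the rest of the encoder needs no changes. A quick sketch of the equivalence (shapes are illustrative; the channel count must stay divisible by the group count):

```python
import functools
import torch
import torch.nn as nn

# Both factories accept a single channel count, so they are interchangeable
# wherever the encoder calls norm_layer(planes).
bn_factory = nn.BatchNorm2d
gn_factory = functools.partial(nn.GroupNorm, 8)  # pre-binds num_groups=8

x = torch.randn(2, 64, 16, 16)
print(bn_factory(64)(x).shape)  # torch.Size([2, 64, 16, 16])
print(gn_factory(64)(x).shape)  # same shape; GroupNorm(8, 64) under the hood
```
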
@@ -812,12 +813,10 @@ class UnetWithBuiltInLatentEncoder(nn.Module):
         }
         super().__init__()
         self.encoder = ResNetEncoder(depth=depth_map[kwargs['image_size']])
+        self.lq_jitter = ColorJitter(.05, .05, .05, .05)
         self.unet = SuperResModel(**kwargs)

     def forward(self, x, timesteps, alt_hq, low_res=None, **kwargs):
         latent = self.encoder(alt_hq)
+        low_res = self.lq_jitter((low_res+1)/2)*2-1
         return self.unet(x, timesteps, latent, low_res, **kwargs)
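
Note the remap around `lq_jitter`: the model works on images in [-1, 1], while kornia's augmentations expect [0, 1], so `low_res` is shifted into [0, 1], jittered, then shifted back. The same round trip in isolation (tensor shape is hypothetical):

```python
import torch
from kornia.augmentation import ColorJitter

lq_jitter = ColorJitter(.05, .05, .05, .05)

low_res = torch.rand(1, 3, 32, 32) * 2 - 1       # model-space image in [-1, 1]
jittered = lq_jitter((low_res + 1) / 2) * 2 - 1  # [0, 1] for kornia, then back
```
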
@@ -97,7 +97,8 @@ class Trainer:

         torch.backends.cudnn.benchmark = True
         # torch.backends.cudnn.deterministic = True
-        # torch.autograd.set_detect_anomaly(True)
+        if opt_get(opt, ['anomaly_detection'], False):
+            torch.autograd.set_detect_anomaly(True)

         # Save the compiled opt dict to the global loaded_options variable.
         util.loaded_options = opt
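
With this change, full anomaly detection becomes opt-in from the experiment options instead of a line you uncomment. A hypothetical options fragment (only the `anomaly_detection` key is the one read by `opt_get` above; everything else is illustrative):

```python
# Hypothetical excerpt of a training opt dict; only 'anomaly_detection' is new.
opt = {
    'name': 'my_experiment',        # illustrative
    'anomaly_detection': True,      # -> torch.autograd.set_detect_anomaly(True)
    # ... remaining trainer options ...
}
```

`set_detect_anomaly` makes every backward pass record forward-op stack traces, so it is far slower than the `check_grads_for_nan` half-measure added below.
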
@@ -254,6 +254,13 @@ class ExtensibleTrainer(BaseModel):
                 # And finally perform optimization.
                 [e.before_optimize(state) for e in self.experiments]
                 s.do_step(step)
+
+                if s.nan_counter > 10:
+                    print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
+                    self.save(step)
+                    self.save_training_state(0, step)
+                    raise ArithmeticError
+
                 # Call into custom step hooks as well as update EMA params.
                 for name, net in self.networks.items():
                     if hasattr(net, "custom_optimizer_step"):
@@ -28,6 +28,13 @@ class ConfigurableStep(Module):
         self.grads_generated = False
         self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999

+        # This is a half-measure that can be used between anomaly_detection and running a potentially problematic
+        # trainer bare. With this turned on, the optimizer will not step() if a NaN grad is detected. If a model trips
+        # this warning 10 times in a row, the training session is aborted and the model state is saved. This has a
+        # noticeable effect on training speed, but nowhere near as bad as anomaly_detection.
+        self.check_grads_for_nan = opt_get(opt_step, ['check_grads_for_nan'], False)
+        self.nan_counter = 0
+
         self.injectors = []
         if 'injectors' in self.step_opt.keys():
             injector_names = []
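
The comment block above describes the mechanism; for reference, here is a minimal, self-contained version of the grad scan it enables (the helper name is hypothetical, not part of the diff; note this sketch uses `.all()`, so a single non-finite element trips it, a slightly stricter test than the `.any()` form in the hunk below):

```python
import torch

def grads_are_finite(optimizer):
    # Scan every parameter's grad; bail out on the first non-finite value.
    for group in optimizer.param_groups:
        for p in group['params']:
            if p.grad is not None and not torch.isfinite(p.grad).all():
                return False
    return True
```
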
@@ -244,8 +251,25 @@ class ConfigurableStep(Module):
             before = opt._config['before'] if 'before' in opt._config.keys() else -1
             if before != -1 and self.env['step'] > before:
                 continue
-            self.scaler.step(opt)
-            self.scaler.update()
+
+            nan_found = False
+            if self.check_grads_for_nan:
+                for pg in opt.param_groups:
+                    for p in pg['params']:
+                        if not torch.isfinite(p.grad).any():
+                            nan_found = True
+                            break
+                    if nan_found:
+                        break
+                if nan_found:
+                    print("NaN found in grads. Throwing this step out.")
+                    self.nan_counter += 1
+                else:
+                    self.nan_counter = 0
+
+            if not nan_found:
+                self.scaler.step(opt)
+                self.scaler.update()

     def get_metrics(self):
         return self.loss_accumulator.as_dict()
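
For context, the guarded step amounts to the following pattern around the AMP scaler (a sketch under the assumption of a single optimizer, reusing the hypothetical `grads_are_finite` helper from earlier):

```python
# Hypothetical guarded step; `scaler` is a torch.cuda.amp.GradScaler.
def guarded_step(scaler, optimizer, step_state):
    if grads_are_finite(optimizer):
        step_state['nan_counter'] = 0
        scaler.step(optimizer)  # unscales grads if needed, then steps
        scaler.update()         # refreshes the loss scale
    else:
        print("NaN found in grads. Throwing this step out.")
        step_state['nan_counter'] += 1  # 10 trips in a row abort training upstream
```

When the check trips, both `scaler.step()` and `scaler.update()` are skipped, so the whole iteration is discarded and the loss-scale state carries over unchanged into the next step.
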