diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py index 796db99a..6af7d000 100644 --- a/codes/data/image_corruptor.py +++ b/codes/data/image_corruptor.py @@ -1,8 +1,13 @@ +import functools import random from math import cos, pi import cv2 +import kornia import numpy as np +import torch +from kornia.augmentation import ColorJitter + from data.util import read_img from PIL import Image from io import BytesIO @@ -65,29 +70,48 @@ class ImageCorruptor: # Sources of entropy corrupted_imgs = [] entropy = [] + undo_fns = [] applied_augs = augmentations + self.fixed_corruptions for img in imgs: for aug in augmentations: r = self.get_rand() - img = self.apply_corruption(img, aug, r, applied_augs) + img, undo_fn = self.apply_corruption(img, aug, r, applied_augs) + if undo_fn is not None: + undo_fns.append(undo_fn) for aug in self.fixed_corruptions: r = self.get_rand() - img = self.apply_corruption(img, aug, r, applied_augs) + img, undo_fn = self.apply_corruption(img, aug, r, applied_augs) entropy.append(r) + if undo_fn is not None: + undo_fns.append(undo_fn) + # Apply undo_fns after all corruptions are finished, in same order. + for ufn in undo_fns: + img = ufn(img) corrupted_imgs.append(img) + if return_entropy: return corrupted_imgs, entropy else: return corrupted_imgs def apply_corruption(self, img, aug, rand_val, applied_augmentations): + undo_fn = None if 'color_quantization' in aug: # Color quantization quant_div = 2 ** (int(rand_val * 10 / 3) + 2) img = img * 255 img = (img // quant_div) * quant_div img = img / 255 + elif 'color_jitter' in aug: + lo_end = 0 + hi_end = .2 + setting = rand_val * (hi_end - lo_end) + lo_end + if setting * 255 > 1: + # I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format. 
+ img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0) + img = ColorJitter(setting, setting, setting, setting)(img) + img = img.squeeze(0).permute(1,2,0).numpy() elif 'gaussian_blur' in aug: img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5) elif 'motion_blur' in aug: @@ -105,14 +129,23 @@ class ImageCorruptor: pass elif 'lq_resampling' in aug: # Random mode interpolation HR->LR->HR - scale = 2 if 'lq_resampling4x' == aug: scale = 4 - interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4] - mode = random.randint(0,4) % len(interpolation_modes) - # Downsample first, then upsample using the random mode. - img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST) - img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode) + else: + if rand_val < .3: + scale = 1 + elif rand_val < .7: + scale = 2 + else: + scale = 4 + if scale > 1: + interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4] + mode = random.randint(0,4) % len(interpolation_modes) + # Downsample using the random mode; upsampling by the same factor is deferred to undo_fn, which uses INTER_LINEAR.
+ img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=mode) + def lq_resampling_undo_fn(scale, img): + return cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=cv2.INTER_LINEAR) + undo_fn = functools.partial(lq_resampling_undo_fn, scale) elif 'color_shift' in aug: # Color shift pass @@ -127,8 +160,8 @@ class ImageCorruptor: if 'noise-5' == aug: noise_intensity = 5 / 255.0 else: - noise_intensity = (rand_val*4 + 2) / 255.0 - img += np.random.randn(*img.shape) * noise_intensity + noise_intensity = (rand_val*6) / 255.0 + img += np.random.rand(*img.shape) * noise_intensity elif 'jpeg' in aug: if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations: if aug == 'jpeg': @@ -162,7 +195,9 @@ class ImageCorruptor: # Lightening / saturation saturation = rand_val * .3 img = np.clip(img + saturation, a_max=1, a_min=0) + elif 'greyscale' in aug: + img = np.tile(np.mean(img, axis=2, keepdims=True), [1,1,3]) elif 'none' not in aug: raise NotImplementedError("Augmentation doesn't exist") - return img + return img, undo_fn diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index 157b4007..e263d8f5 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -266,14 +266,14 @@ if __name__ == '__main__': 'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked'], 'weights': [1], 'target_size': 256, - 'scale': 2, + 'scale': 1, 'corrupt_before_downsize': True, 'fetch_alt_image': False, 'fetch_alt_tiled_image': True, 'disable_flip': True, - 'fixed_corruptions': [ 'jpeg-medium' ], + 'fixed_corruptions': ['lq_resampling', 'jpeg-medium', 'gaussian_blur', 'noise', 'color_jitter'], 'num_corrupts_per_image': 0, - 'corruption_blur_scale': 0 + 'corruption_blur_scale': 1 } ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0, batch_size=64) diff --git a/codes/models/diffusion/unet_latent_guide.py 
b/codes/models/diffusion/unet_latent_guide.py index 4c1881ac..41dd85a8 100644 --- a/codes/models/diffusion/unet_latent_guide.py +++ b/codes/models/diffusion/unet_latent_guide.py @@ -1,3 +1,4 @@ +import functools from abc import abstractmethod import math @@ -702,7 +703,7 @@ class ResNetEncoder(nn.Module): ) -> None: super(ResNetEncoder, self).__init__() if norm_layer is None: - norm_layer = nn.BatchNorm2d + norm_layer = functools.partial(nn.GroupNorm, 8) self._norm_layer = norm_layer self.inplanes = 64 @@ -812,12 +813,10 @@ class UnetWithBuiltInLatentEncoder(nn.Module): } super().__init__() self.encoder = ResNetEncoder(depth=depth_map[kwargs['image_size']]) - self.lq_jitter = ColorJitter(.05, .05, .05, .05) self.unet = SuperResModel(**kwargs) def forward(self, x, timesteps, alt_hq, low_res=None, **kwargs): latent = self.encoder(alt_hq) - low_res = self.lq_jitter((low_res+1)/2)*2-1 return self.unet(x, timesteps, latent, low_res, **kwargs) diff --git a/codes/train.py b/codes/train.py index df293997..fbc2906a 100644 --- a/codes/train.py +++ b/codes/train.py @@ -97,7 +97,8 @@ class Trainer: torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True - # torch.autograd.set_detect_anomaly(True) + if opt_get(opt, ['anomaly_detection'], False): + torch.autograd.set_detect_anomaly(True) # Save the compiled opt dict to the global loaded_options variable. util.loaded_options = opt diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index aa0bee3e..7da22bdd 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -254,6 +254,13 @@ class ExtensibleTrainer(BaseModel): # And finally perform optimization. [e.before_optimize(state) for e in self.experiments] s.do_step(step) + + if s.nan_counter > 10: + print("Detected NaN grads more than 10 steps in a row. 
Saving model weights and aborting.") + self.save(step) + self.save_training_state(0, step) + raise ArithmeticError + # Call into custom step hooks as well as update EMA params. for name, net in self.networks.items(): if hasattr(net, "custom_optimizer_step"): diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index b8d52f9e..cc50b9c2 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -28,6 +28,13 @@ class ConfigurableStep(Module): self.grads_generated = False self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999 + # This is a half-measure that can be used between anomaly_detection and running a potentially problematic + # trainer bare. With this turned on, the optimizer will not step() if a nan grad is detected. If a model trips + # this warning more than 10 times in a row, the training session is aborted and the model state is saved. This has a + # noticeable effect on training speed, but nowhere near as bad as anomaly_detection. + self.check_grads_for_nan = opt_get(opt_step, ['check_grads_for_nan'], False) + self.nan_counter = 0 + self.injectors = [] if 'injectors' in self.step_opt.keys(): injector_names = [] @@ -244,8 +251,25 @@ class ConfigurableStep(Module): before = opt._config['before'] if 'before' in opt._config.keys() else -1 if before != -1 and self.env['step'] > before: continue - self.scaler.step(opt) - self.scaler.update() + + nan_found = False + if self.check_grads_for_nan: + for pg in opt.param_groups: + for p in pg['params']: + if not torch.isfinite(p.grad).any(): + nan_found = True + break + if nan_found: + break + if nan_found: + print("NaN found in grads. Throwing this step out.") + self.nan_counter += 1 + else: + self.nan_counter = 0 + + if not nan_found: + self.scaler.step(opt) + self.scaler.update() def get_metrics(self): return self.loss_accumulator.as_dict()