From 3fd627fc620bf52ed9a2722f39b1c22d950d1005 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 26 Dec 2020 13:49:27 -0700
Subject: [PATCH] Mods to support image classification & filtering

---
 codes/data/image_corruptor.py             |  4 +-
 codes/models/resnet_with_checkpointing.py |  6 ++-
 codes/scripts/extract_square_images.py    | 10 ++---
 codes/train.py                            |  2 +-
 codes/trainer/injectors.py                | 51 +++++++++++++++++++++++
 recipes/glean/train_ffhq_glean.yml        |  2 +-
 6 files changed, 65 insertions(+), 10 deletions(-)

diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py
index 323ddd09..345423f3 100644
--- a/codes/data/image_corruptor.py
+++ b/codes/data/image_corruptor.py
@@ -9,12 +9,12 @@ from io import BytesIO
 # options.
 class ImageCorruptor:
     def __init__(self, opt):
+        self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
         self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
         self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
         if self.num_corrupts == 0:
             return
         self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
-        self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
 
     def corrupt_images(self, imgs):
         if self.num_corrupts == 0 and not self.fixed_corruptions:
@@ -77,7 +77,7 @@ class ImageCorruptor:
                 scale = 2
             if 'lq_resampling4x' == aug:
                 scale = 4
-            interpolation_modes = [cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
+            interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
             mode = rand_int % len(interpolation_modes)
             # Downsample first, then upsample using the random mode.
             img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
diff --git a/codes/models/resnet_with_checkpointing.py b/codes/models/resnet_with_checkpointing.py
index 4ab8f3c2..94134d77 100644
--- a/codes/models/resnet_with_checkpointing.py
+++ b/codes/models/resnet_with_checkpointing.py
@@ -193,4 +193,8 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
 
 @register_model
 def register_resnet52(opt_net, opt):
-    return resnet50(pretrained=opt_net['pretrained'])
+    model = resnet50(pretrained=opt_net['pretrained'])
+    if opt_net['custom_head_logits']:
+        model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
+    return model
+
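Note on the head swap above: torchvision's ResNet keeps its classifier in
`model.fc`, and resnet50's final Bottleneck stage emits 512 * expansion (= 2048)
features, which is where the `512 * 4` comes from. A minimal sketch of what the
registered model reduces to — the `opt_net` dict here is a hypothetical example,
not taken from the repo's configs:

    import torch
    import torch.nn as nn
    from torchvision.models import resnet50

    # Hypothetical options; 'custom_head_logits' is the knob this patch adds.
    opt_net = {'pretrained': False, 'custom_head_logits': 2}

    model = resnet50(pretrained=opt_net['pretrained'])
    if opt_net['custom_head_logits']:
        # Replace the stock 1000-way ImageNet head with a small logits head.
        model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])

    logits = model(torch.randn(1, 3, 224, 224))
    assert logits.shape == (1, 2)
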
diff --git a/codes/scripts/extract_square_images.py b/codes/scripts/extract_square_images.py
index 57d61668..09be5bb9 100644
--- a/codes/scripts/extract_square_images.py
+++ b/codes/scripts/extract_square_images.py
@@ -13,15 +13,15 @@ import torch
 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 8
+    opt['n_thread'] = 4
     opt['compression_level'] = 90  # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.
 
     opt['dest'] = 'file'
-    opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised']
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised'
-    opt['imgsize'] = 256
+    opt['input_folder'] = ['F:\\4k6k\\datasets\\images\\youtube\\images']
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\images\\ge_full_1024'
+    opt['imgsize'] = 1024
     #opt['bottom_crop'] = 120
 
     save_folder = opt['save_folder']
@@ -61,7 +61,7 @@ class TiledDataset(data.Dataset):
         h, w, c = img.shape
 
         # Uncomment to filter any image that doesnt meet a threshold size.
-        if min(h,w) < 512:
+        if min(h,w) < 1024:
             return None
 
         # We must convert the image into a square.
diff --git a/codes/train.py b/codes/train.py
index 86d45886..f3ccab0a 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -293,7 +293,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_diffimage.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/injectors.py b/codes/trainer/injectors.py
index c8248867..36956bb7 100644
--- a/codes/trainer/injectors.py
+++ b/codes/trainer/injectors.py
@@ -1,6 +1,8 @@
+import os
 import random
 
 import torch.nn
+import torchvision
 from torch.cuda.amp import autocast
 
 from utils.weight_scheduler import get_scheduler_for_opt
@@ -53,6 +55,10 @@ def create_injector(opt_inject, env):
         return SrDiffsInjector(opt_inject, env)
     elif type == 'multiframe_combiner':
         return MultiFrameCombiner(opt_inject, env)
+    elif type == 'mix_and_label':
+        return MixAndLabelInjector(opt_inject, env)
+    elif type == 'save_images':
+        return SaveImages(opt_inject, env)
     else:
         raise NotImplementedError
 
@@ -409,3 +415,48 @@
             return self.combine(state)
         else:
             raise NotImplementedError
+
+
+# Combines data from multiple different sources and mixes them along the batch dimension. Labels are then emitted
+# according to how the mixing was performed.
+class MixAndLabelInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.out_labels = opt['out_labels']
+
+    def forward(self, state):
+        input_tensors = [state[i] for i in self.input]
+        num_inputs = len(input_tensors)
+        bs = input_tensors[0].shape[0]
+        labels = torch.randint(0, num_inputs, (bs,), device=input_tensors[0].device)
+        # Still don't know of a good way to do this in torch.. TODO make it better..
+        res = []
+        for b in range(bs):
+            res.append(input_tensors[labels[b]][b, :, :, :])
+        output = torch.stack(res, dim=0)
+        return { self.out_labels: labels, self.output: output }
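The TODO above ("still don't know of a good way to do this in torch") can be
satisfied with plain advanced indexing: stack the sources into a
(num_inputs, bs, ...) tensor and index it with the label vector. A sketch using
stand-in tensors, not part of the patch:

    import torch

    bs = 4
    # Three identically-shaped sources, as the injector expects.
    input_tensors = [torch.full((bs, 3, 8, 8), float(i)) for i in range(3)]

    labels = torch.randint(0, len(input_tensors), (bs,))
    stacked = torch.stack(input_tensors, dim=0)   # (num_inputs, bs, C, H, W)
    output = stacked[labels, torch.arange(bs)]    # source labels[b] for sample b

    # Matches the per-sample loop in the hunk above.
    expected = torch.stack([input_tensors[labels[b]][b] for b in range(bs)], dim=0)
    assert torch.equal(output, expected)
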
+
+
+# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering
+# using ExtensibleTrainer.
+class SaveImages(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.logits = opt['logits']
+        self.target = opt['target']
+        self.thresh = opt['threshold']
+        self.index = 0
+        self.run_id = random.randint(0, 999999)
+        self.savedir = opt['savedir']
+        os.makedirs(self.savedir, exist_ok=True)
+        self.softmax = torch.nn.Softmax(dim=1)
+
+    def forward(self, state):
+        logits = self.softmax(state[self.logits])
+        images = state[self.input]
+        bs = images.shape[0]
+        for b in range(bs):
+            if logits[b][self.target] > self.thresh:
+                torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
+                self.index += 1
+        return {}
\ No newline at end of file
diff --git a/recipes/glean/train_ffhq_glean.yml b/recipes/glean/train_ffhq_glean.yml
index b94e15e6..d1175e2a 100644
--- a/recipes/glean/train_ffhq_glean.yml
+++ b/recipes/glean/train_ffhq_glean.yml
@@ -24,7 +24,7 @@ networks:
     type: generator
     which_model_G: glean
     nf: 64
-    pretrained_stylegan: ../experiments/stylegan2-ffhq-config-f.pth
+    latent_bank_pretrained_weights: ../experiments/stylegan2-ffhq-config-f.pth
   feature_discriminator:
     type: discriminator
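
For reference, the filtering rule SaveImages implements reduces to: softmax the
classifier logits, then write out only the samples whose target-class
probability clears the threshold. A self-contained sketch of that criterion,
with hypothetical stand-ins for the ExtensibleTrainer state keys:

    import os
    import torch
    import torchvision

    images = torch.rand(8, 3, 64, 64)   # stand-in for state[self.input]
    logits = torch.randn(8, 2)          # stand-in for state[self.logits]
    target, thresh, savedir = 1, 0.7, '/tmp/filtered'

    os.makedirs(savedir, exist_ok=True)
    probs = torch.nn.Softmax(dim=1)(logits)
    for b in range(images.shape[0]):
        if probs[b][target] > thresh:
            torchvision.utils.save_image(images[b], os.path.join(savedir, f'{b}.jpg'))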