Mods to support image classification & filtering

This commit is contained in:
James Betker 2020-12-26 13:49:27 -07:00
parent 10fdfa1563
commit 3fd627fc62
6 changed files with 65 additions and 10 deletions

View File

@ -9,12 +9,12 @@ from io import BytesIO
# options. # options.
class ImageCorruptor: class ImageCorruptor:
def __init__(self, opt): def __init__(self, opt):
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else [] self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0 self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
if self.num_corrupts == 0: if self.num_corrupts == 0:
return return
self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else [] self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
def corrupt_images(self, imgs): def corrupt_images(self, imgs):
if self.num_corrupts == 0 and not self.fixed_corruptions: if self.num_corrupts == 0 and not self.fixed_corruptions:
@ -77,7 +77,7 @@ class ImageCorruptor:
scale = 2 scale = 2
if 'lq_resampling4x' == aug: if 'lq_resampling4x' == aug:
scale = 4 scale = 4
interpolation_modes = [cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4] interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
mode = rand_int % len(interpolation_modes) mode = rand_int % len(interpolation_modes)
# Downsample first, then upsample using the random mode. # Downsample first, then upsample using the random mode.
img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST) img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)

View File

@ -193,4 +193,8 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
@register_model @register_model
def register_resnet52(opt_net, opt): def register_resnet52(opt_net, opt):
return resnet50(pretrained=opt_net['pretrained']) model = resnet50(pretrained=opt_net['pretrained'])
if opt_net['custom_head_logits']:
model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
return model

View File

@ -13,15 +13,15 @@ import torch
def main(): def main():
split_img = False split_img = False
opt = {} opt = {}
opt['n_thread'] = 8 opt['n_thread'] = 4
opt['compression_level'] = 90 # JPEG compression quality rating. opt['compression_level'] = 90 # JPEG compression quality rating.
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed. # compression time. If read raw images during training, use 0 for faster IO speed.
opt['dest'] = 'file' opt['dest'] = 'file'
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'] opt['input_folder'] = ['F:\\4k6k\\datasets\\images\\youtube\\images']
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised' opt['save_folder'] = 'F:\\4k6k\\datasets\\images\\ge_full_1024'
opt['imgsize'] = 256 opt['imgsize'] = 1024
#opt['bottom_crop'] = 120 #opt['bottom_crop'] = 120
save_folder = opt['save_folder'] save_folder = opt['save_folder']
@ -61,7 +61,7 @@ class TiledDataset(data.Dataset):
h, w, c = img.shape h, w, c = img.shape
# Uncomment to filter any image that doesnt meet a threshold size. # Uncomment to filter any image that doesnt meet a threshold size.
if min(h,w) < 512: if min(h,w) < 1024:
return None return None
# We must convert the image into a square. # We must convert the image into a square.

View File

@ -293,7 +293,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_diffimage.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,6 +1,8 @@
import os
import random import random
import torch.nn import torch.nn
import torchvision
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from utils.weight_scheduler import get_scheduler_for_opt from utils.weight_scheduler import get_scheduler_for_opt
@ -53,6 +55,10 @@ def create_injector(opt_inject, env):
return SrDiffsInjector(opt_inject, env) return SrDiffsInjector(opt_inject, env)
elif type == 'multiframe_combiner': elif type == 'multiframe_combiner':
return MultiFrameCombiner(opt_inject, env) return MultiFrameCombiner(opt_inject, env)
elif type == 'mix_and_label':
return MixAndLabelInjector(opt_inject, env)
elif type == 'save_images':
return SaveImages(opt_inject, env)
else: else:
raise NotImplementedError raise NotImplementedError
@ -409,3 +415,48 @@ class MultiFrameCombiner(Injector):
return self.combine(state) return self.combine(state)
else: else:
raise NotImplementedError raise NotImplementedError
# Combines data from multiple different sources and mixes them along the batch dimension. Labels are then emitted
# according to how the mixing was performed.
class MixAndLabelInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.out_labels = opt['out_labels']
def forward(self, state):
input_tensors = [state[i] for i in self.input]
num_inputs = len(input_tensors)
bs = input_tensors[0].shape[0]
labels = torch.randint(0, num_inputs, (bs,), device=input_tensors[0].device)
# Still don't know of a good way to do this in torch.. TODO make it better..
res = []
for b in range(bs):
res.append(input_tensors[labels[b]][b, :, :, :])
output = torch.stack(res, dim=0)
return { self.out_labels: labels, self.output: output }
# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering
# using ExtensibleTrainer.
class SaveImages(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.logits = opt['logits']
self.target = opt['target']
self.thresh = opt['threshold']
self.index = 0
self.run_id = random.randint(0, 999999)
self.savedir = opt['savedir']
os.makedirs(self.savedir, exist_ok=True)
self.softmax = torch.nn.Softmax(dim=1)
def forward(self, state):
logits = self.softmax(state[self.logits])
images = state[self.input]
bs = images.shape[0]
for b in range(bs):
if logits[b][self.target] > self.thresh:
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
self.index += 1
return {}

View File

@ -24,7 +24,7 @@ networks:
type: generator type: generator
which_model_G: glean which_model_G: glean
nf: 64 nf: 64
pretrained_stylegan: ../experiments/stylegan2-ffhq-config-f.pth latent_bank_pretrained_weights: ../experiments/stylegan2-ffhq-config-f.pth
feature_discriminator: feature_discriminator:
type: discriminator type: discriminator