forked from mrq/DL-Art-School
Mods to support image classification & filtering
This commit is contained in:
parent
10fdfa1563
commit
3fd627fc62
|
@ -9,12 +9,12 @@ from io import BytesIO
|
|||
# options.
|
||||
class ImageCorruptor:
|
||||
def __init__(self, opt):
|
||||
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
|
||||
self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
|
||||
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
|
||||
if self.num_corrupts == 0:
|
||||
return
|
||||
self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
|
||||
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
|
||||
|
||||
def corrupt_images(self, imgs):
|
||||
if self.num_corrupts == 0 and not self.fixed_corruptions:
|
||||
|
@ -77,7 +77,7 @@ class ImageCorruptor:
|
|||
scale = 2
|
||||
if 'lq_resampling4x' == aug:
|
||||
scale = 4
|
||||
interpolation_modes = [cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
|
||||
interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
|
||||
mode = rand_int % len(interpolation_modes)
|
||||
# Downsample first, then upsample using the random mode.
|
||||
img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
|
||||
|
|
|
@ -193,4 +193,8 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
|||
|
||||
@register_model
|
||||
def register_resnet52(opt_net, opt):
|
||||
return resnet50(pretrained=opt_net['pretrained'])
|
||||
model = resnet50(pretrained=opt_net['pretrained'])
|
||||
if opt_net['custom_head_logits']:
|
||||
model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
|
||||
return model
|
||||
|
||||
|
|
|
@ -13,15 +13,15 @@ import torch
|
|||
def main():
|
||||
split_img = False
|
||||
opt = {}
|
||||
opt['n_thread'] = 8
|
||||
opt['n_thread'] = 4
|
||||
opt['compression_level'] = 90 # JPEG compression quality rating.
|
||||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||
|
||||
opt['dest'] = 'file'
|
||||
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised']
|
||||
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised'
|
||||
opt['imgsize'] = 256
|
||||
opt['input_folder'] = ['F:\\4k6k\\datasets\\images\\youtube\\images']
|
||||
opt['save_folder'] = 'F:\\4k6k\\datasets\\images\\ge_full_1024'
|
||||
opt['imgsize'] = 1024
|
||||
#opt['bottom_crop'] = 120
|
||||
|
||||
save_folder = opt['save_folder']
|
||||
|
@ -61,7 +61,7 @@ class TiledDataset(data.Dataset):
|
|||
|
||||
h, w, c = img.shape
|
||||
# Uncomment to filter any image that doesnt meet a threshold size.
|
||||
if min(h,w) < 512:
|
||||
if min(h,w) < 1024:
|
||||
return None
|
||||
|
||||
# We must convert the image into a square.
|
||||
|
|
|
@ -293,7 +293,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_diffimage.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
import torch.nn
|
||||
import torchvision
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from utils.weight_scheduler import get_scheduler_for_opt
|
||||
|
@ -53,6 +55,10 @@ def create_injector(opt_inject, env):
|
|||
return SrDiffsInjector(opt_inject, env)
|
||||
elif type == 'multiframe_combiner':
|
||||
return MultiFrameCombiner(opt_inject, env)
|
||||
elif type == 'mix_and_label':
|
||||
return MixAndLabelInjector(opt_inject, env)
|
||||
elif type == 'save_images':
|
||||
return SaveImages(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -409,3 +415,48 @@ class MultiFrameCombiner(Injector):
|
|||
return self.combine(state)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Combines data from multiple different sources and mixes them along the batch dimension. Labels are then emitted
|
||||
# according to how the mixing was performed.
|
||||
class MixAndLabelInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.out_labels = opt['out_labels']
|
||||
|
||||
def forward(self, state):
|
||||
input_tensors = [state[i] for i in self.input]
|
||||
num_inputs = len(input_tensors)
|
||||
bs = input_tensors[0].shape[0]
|
||||
labels = torch.randint(0, num_inputs, (bs,), device=input_tensors[0].device)
|
||||
# Still don't know of a good way to do this in torch.. TODO make it better..
|
||||
res = []
|
||||
for b in range(bs):
|
||||
res.append(input_tensors[labels[b]][b, :, :, :])
|
||||
output = torch.stack(res, dim=0)
|
||||
return { self.out_labels: labels, self.output: output }
|
||||
|
||||
|
||||
# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering
|
||||
# using ExtensibleTrainer.
|
||||
class SaveImages(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.logits = opt['logits']
|
||||
self.target = opt['target']
|
||||
self.thresh = opt['threshold']
|
||||
self.index = 0
|
||||
self.run_id = random.randint(0, 999999)
|
||||
self.savedir = opt['savedir']
|
||||
os.makedirs(self.savedir, exist_ok=True)
|
||||
self.softmax = torch.nn.Softmax(dim=1)
|
||||
|
||||
def forward(self, state):
|
||||
logits = self.softmax(state[self.logits])
|
||||
images = state[self.input]
|
||||
bs = images.shape[0]
|
||||
for b in range(bs):
|
||||
if logits[b][self.target] > self.thresh:
|
||||
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
|
||||
self.index += 1
|
||||
return {}
|
|
@ -24,7 +24,7 @@ networks:
|
|||
type: generator
|
||||
which_model_G: glean
|
||||
nf: 64
|
||||
pretrained_stylegan: ../experiments/stylegan2-ffhq-config-f.pth
|
||||
latent_bank_pretrained_weights: ../experiments/stylegan2-ffhq-config-f.pth
|
||||
|
||||
feature_discriminator:
|
||||
type: discriminator
|
||||
|
|
Loading…
Reference in New Issue
Block a user