Merge remote-tracking branch 'origin/gan_lab' into gan_lab
commit 2706a84f15
@@ -9,12 +9,12 @@ from io import BytesIO
 # options.
 class ImageCorruptor:
     def __init__(self, opt):
+        self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
         self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
         self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
         if self.num_corrupts == 0:
             return
         self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
-        self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1

     def corrupt_images(self, imgs):
         if self.num_corrupts == 0 and not self.fixed_corruptions:
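Note: the point of this hunk is an initialization-order fix. `self.blur_scale` was previously assigned after the early `return`, so any config with `num_corrupts_per_image == 0` produced an ImageCorruptor that never got the attribute; moving the assignment above the guard makes it always present. A minimal standalone sketch of the pitfall (names borrowed from the diff, trimmed for illustration):

    class Example:
        def __init__(self, opt):
            self.num_corrupts = opt.get('num_corrupts_per_image', 0)
            if self.num_corrupts == 0:
                return  # everything assigned below is skipped
            self.blur_scale = opt.get('corruption_blur_scale', 1)

    e = Example({'num_corrupts_per_image': 0})
    print(hasattr(e, 'blur_scale'))  # False -> AttributeError on first use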
@@ -77,7 +77,7 @@ class ImageCorruptor:
             scale = 2
             if 'lq_resampling4x' == aug:
                 scale = 4
-            interpolation_modes = [cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
+            interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
             mode = rand_int % len(interpolation_modes)
             # Downsample first, then upsample using the random mode.
             img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
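Note: this hunk drops cv2.INTER_AREA from the candidate list, leaving only filters that make sense for the upsampling half of the round trip (INTER_AREA is primarily a decimation filter). A hedged sketch of the full corruption as the surrounding comment describes it; the upsampling call and `rand_int` are assumed from context, not shown in the diff:

    import cv2

    def lq_resample(img, rand_int, scale=2):
        # Candidate upsampling filters, mirroring the new list in the hunk.
        modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
        mode = modes[rand_int % len(modes)]
        h, w = img.shape[:2]
        # Downsample first, then upsample using the randomly chosen mode.
        small = cv2.resize(img, dsize=(w // scale, h // scale), interpolation=cv2.INTER_NEAREST)
        return cv2.resize(small, dsize=(w, h), interpolation=mode)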
@@ -193,4 +193,8 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):

 @register_model
 def register_resnet52(opt_net, opt):
-    return resnet50(pretrained=opt_net['pretrained'])
+    model = resnet50(pretrained=opt_net['pretrained'])
+    if opt_net['custom_head_logits']:
+        model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
+    return model
+
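Note: `register_resnet52` now optionally swaps the classification head. In torchvision's resnet50 the final stage uses Bottleneck blocks with expansion 4, so `fc.in_features` is 512 * 4 = 2048; replacing `fc` changes only the number of output logits. A minimal usage sketch (the logit count is a hypothetical value; newer torchvision uses `weights=` instead of `pretrained=`):

    import torch.nn as nn
    from torchvision.models import resnet50

    model = resnet50(pretrained=False)
    custom_head_logits = 2  # hypothetical: e.g. a binary blur/no-blur classifier
    model.fc = nn.Linear(512 * 4, custom_head_logits)  # 512 * 4 == model.fc.in_features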
@@ -13,15 +13,15 @@ import torch
 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 8
+    opt['n_thread'] = 4
     opt['compression_level'] = 90  # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.

     opt['dest'] = 'file'
-    opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised']
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised'
-    opt['imgsize'] = 256
+    opt['input_folder'] = ['F:\\4k6k\\datasets\\images\\youtube\\images']
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\images\\ge_full_1024'
+    opt['imgsize'] = 1024
     #opt['bottom_crop'] = 120

     save_folder = opt['save_folder']
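Aside on the options above: `compression_level` is documented two ways at once. As used here (90) it reads as a JPEG quality rating on a 0-100 scale, while the scale described in the comment (CV_IMWRITE_PNG_COMPRESSION, 0-9) applies to PNG output. In OpenCV these are separate imwrite flags; a sketch of the distinction:

    import cv2
    import numpy as np

    img = np.zeros((8, 8, 3), dtype=np.uint8)  # placeholder image
    cv2.imwrite('out.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, 90])    # quality: 0-100, higher = better
    cv2.imwrite('out.png', img, [cv2.IMWRITE_PNG_COMPRESSION, 0])  # compression: 0-9, higher = smaller/slower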
@@ -61,7 +61,7 @@ class TiledDataset(data.Dataset):

         h, w, c = img.shape
         # Uncomment to filter any image that doesnt meet a threshold size.
-        if min(h,w) < 512:
+        if min(h,w) < 1024:
             return None

         # We must convert the image into a square.
@@ -293,7 +293,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_diffimage.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -1,6 +1,8 @@
+import os
 import random

 import torch.nn
+import torchvision
 from torch.cuda.amp import autocast

 from utils.weight_scheduler import get_scheduler_for_opt
@@ -53,6 +55,10 @@ def create_injector(opt_inject, env):
         return SrDiffsInjector(opt_inject, env)
     elif type == 'multiframe_combiner':
         return MultiFrameCombiner(opt_inject, env)
+    elif type == 'mix_and_label':
+        return MixAndLabelInjector(opt_inject, env)
+    elif type == 'save_images':
+        return SaveImages(opt_inject, env)
     else:
         raise NotImplementedError

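Note: these two branches wire the injectors defined at the bottom of this diff into the factory. A hypothetical `opt_inject` for the new `save_images` type, with keys inferred from the SaveImages constructor below (the base Injector's input/output keys are not shown in this diff and are assumed):

    opt_inject = {
        'type': 'save_images',
        'logits': 'classifier_logits',  # state key holding raw logits (hypothetical name)
        'target': 1,                    # class index to filter for
        'threshold': 0.99,
        'savedir': '/tmp/filtered_images',
        # plus whatever input/output keys the Injector base class expects (assumed, not in this diff)
    }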
@@ -409,3 +415,48 @@ class MultiFrameCombiner(Injector):
             return self.combine(state)
         else:
             raise NotImplementedError
+
+
+# Combines data from multiple different sources and mixes them along the batch dimension. Labels are then emitted
+# according to how the mixing was performed.
+class MixAndLabelInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.out_labels = opt['out_labels']
+
+    def forward(self, state):
+        input_tensors = [state[i] for i in self.input]
+        num_inputs = len(input_tensors)
+        bs = input_tensors[0].shape[0]
+        labels = torch.randint(0, num_inputs, (bs,), device=input_tensors[0].device)
+        # Still don't know of a good way to do this in torch.. TODO make it better..
+        res = []
+        for b in range(bs):
+            res.append(input_tensors[labels[b]][b, :, :, :])
+        output = torch.stack(res, dim=0)
+        return { self.out_labels: labels, self.output: output }
+
+
+# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering
+# using ExtensibleTrainer.
+class SaveImages(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.logits = opt['logits']
+        self.target = opt['target']
+        self.thresh = opt['threshold']
+        self.index = 0
+        self.run_id = random.randint(0, 999999)
+        self.savedir = opt['savedir']
+        os.makedirs(self.savedir, exist_ok=True)
+        self.softmax = torch.nn.Softmax(dim=1)
+
+    def forward(self, state):
+        logits = self.softmax(state[self.logits])
+        images = state[self.input]
+        bs = images.shape[0]
+        for b in range(bs):
+            if logits[b][self.target] > self.thresh:
+                torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
+                self.index += 1
+        return {}
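Note on the TODO inside MixAndLabelInjector.forward: the per-sample Python loop can be replaced by a single batched advanced-indexing op. A hedged sketch with stand-in shapes (the indexing idea, not a drop-in patch):

    import torch

    num_inputs, bs = 3, 4
    input_tensors = [torch.randn(bs, 3, 8, 8) for _ in range(num_inputs)]  # stand-in batches
    labels = torch.randint(0, num_inputs, (bs,))

    stacked = torch.stack(input_tensors, dim=0)  # (num_inputs, bs, c, h, w)
    output = stacked[labels, torch.arange(bs)]   # row b takes input_tensors[labels[b]][b]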
@@ -24,7 +24,7 @@ networks:
     type: generator
     which_model_G: glean
     nf: 64
-    pretrained_stylegan: ../experiments/stylegan2-ffhq-config-f.pth
+    latent_bank_pretrained_weights: ../experiments/stylegan2-ffhq-config-f.pth

   feature_discriminator:
     type: discriminator