Various mods to support better jpeg image filtering

James Betker 2021-06-25 13:16:15 -06:00
parent e7890dc0ba
commit a57ed8e960
8 changed files with 186 additions and 41 deletions

View File

@@ -9,6 +9,8 @@ from io import BytesIO
 # Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data)
+from utils.util import opt_get
 '''
 if __name__ == '__main__':
     import numpy as np
@@ -23,13 +25,15 @@ if __name__ == '__main__':
 class ImageCorruptor:
     def __init__(self, opt):
         self.opt = opt
+        self.reset_random()
         self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
         self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
         self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
+        self.cosine_bias = opt_get(opt, ['cosine_bias'], True)
         if self.num_corrupts == 0:
             return
-        self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
-        self.reset_random()
+        else:
+            self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []

     def reset_random(self):
         if 'random_seed' in self.opt.keys():
@@ -41,7 +45,10 @@ class ImageCorruptor:
     # Return is on [0,1] with a bias towards 0.
     def get_rand(self):
         r = self.rand.random()
-        return 1 - cos(r * pi / 2)
+        if self.cosine_bias:
+            return 1 - cos(r * pi / 2)
+        else:
+            return r

     def corrupt_images(self, imgs, return_entropy=False):
         if self.num_corrupts == 0 and not self.fixed_corruptions:
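As a quick sanity check on the new cosine_bias flag (a standalone sketch, not part of the commit): mapping a uniform draw through 1 - cos(r*pi/2) keeps values on [0,1] but shifts the expected value from 0.5 down to 1 - 2/pi, roughly 0.363, so most corruptions stay mild; disabling the flag falls back to plain uniform sampling.

import random
from math import cos, pi

rng = random.Random(0)
uniform = [rng.random() for _ in range(100_000)]
biased = [1 - cos(r * pi / 2) for r in uniform]
# Uniform mean is ~0.5; the cosine-biased mean is ~0.363 (1 - 2/pi).
print(sum(uniform) / len(uniform), sum(biased) / len(biased))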
@@ -82,10 +89,10 @@ class ImageCorruptor:
                 img = (img // quant_div) * quant_div
                 img = img / 255
             elif 'gaussian_blur' in aug:
-                img = cv2.GaussianBlur(img, (0,0), rand_val*1.5)
+                img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
             elif 'motion_blur' in aug:
                 # Motion blur
-                intensity = self.blur_scale * rand_val * 3 + 1
+                intensity = self.blur_scale*rand_val * 3 + 1
                 angle = random.randint(0,360)
                 k = np.zeros((intensity, intensity), dtype=np.float32)
                 k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
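For reference, the motion-blur branch builds an intensity x intensity kernel with a single horizontal line of ones through its center row, which is then rotated to a random angle and convolved with the image. A minimal standalone sketch of that kernel pattern (assuming OpenCV is available and intensity is a positive integer; this is an illustration, not the commit's exact code path):

import cv2
import numpy as np

def motion_blur_sketch(img, intensity, angle):
    intensity = max(1, int(intensity))
    # Horizontal line kernel through the center row.
    k = np.zeros((intensity, intensity), dtype=np.float32)
    k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
    # Rotate the line to the requested angle and normalize so brightness is preserved.
    rot = cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0)
    k = cv2.warpAffine(k, rot, (intensity, intensity))
    k = k / k.sum()
    return cv2.filter2D(img, -1, k)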

View File

@@ -1,3 +1,4 @@
+import functools
 import glob
 import itertools
 import random
@@ -11,7 +12,7 @@ import os
 import torchvision
 from torch.utils.data import DataLoader
-from torchvision.transforms import Normalize
+from torchvision.transforms import Normalize, CenterCrop
 from tqdm import tqdm

 from data import util
@@ -21,10 +22,19 @@ from data.image_label_parser import VsNetImageLabeler
 from utils.util import opt_get

+def ndarray_center_crop(crop, img):
+    y, x, c = img.shape
+    startx = x // 2 - crop // 2
+    starty = y // 2 - crop // 2
+    return img[starty:starty + crop, startx:startx + crop, :]
+
+
 class ImageFolderDataset:
     def __init__(self, opt):
         self.opt = opt
         self.corruptor = ImageCorruptor(opt)
+        if 'center_crop_hq_sz' in opt.keys():
+            self.center_crop = functools.partial(ndarray_center_crop, opt['center_crop_hq_sz'])
         self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
         self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
         self.scale = opt['scale']
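ndarray_center_crop just slices the central crop x crop window out of an HWC ndarray, and the functools.partial binding turns it into the single-argument self.center_crop callable applied in __getitem__ below. A quick sketch with hypothetical sizes:

import numpy as np

img = np.zeros((512, 768, 3), dtype=np.float32)   # h, w, c
out = ndarray_center_crop(256, img)                # -> shape (256, 256, 3), taken from the middle
# With 'center_crop_hq_sz': 256 in the dataset opt, the same crop happens via
# self.center_crop(hq) before any flipping or corruption.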
@@ -132,6 +142,8 @@ class ImageFolderDataset:
     def __getitem__(self, item):
         hq = util.read_img(None, self.image_paths[item], rgb=True)
+        if hasattr(self, 'center_crop'):
+            hq = self.center_crop(hq)

         if not self.disable_flip and random.random() < .5:
             hq = hq[:, ::-1, :]
@@ -223,25 +235,41 @@ class ImageFolderDataset:
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
+        'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked'],
         'weights': [1],
         'target_size': 256,
+        'force_multiple': 1,
         'scale': 2,
         'corrupt_before_downsize': True,
         'fetch_alt_image': False,
         'disable_flip': True,
-        'fixed_corruptions': [ 'jpeg-broad', 'gaussian_blur' ],
+        'fixed_corruptions': [ 'jpeg-medium' ],
         'num_corrupts_per_image': 0,
         'corruption_blur_scale': 0
     }
-    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0)
+    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=4, batch_size=64)
     import os
-    output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
+    output_path = 'F:\\tmp'
     os.makedirs(output_path, exist_ok=True)
+    res = []
     for i, d in tqdm(enumerate(ds)):
-        lq = d['lq']
-        #torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
+        '''
+        x = d['hq']
+        b,c,h,w = x.shape
+        x_c = x.view(c*b, h, w)
+        x_c = torch.view_as_real(torch.fft.rfft(x_c))
+        # Log-normalize spectrogram
+        x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16)
+        x_c = torch.log(x_c)
+        res.append(x_c)
+        if i % 100 == 99:
+            stacked = torch.cat(res, dim=0)
+            print(stacked.mean(dim=[0,1,2]), stacked.std(dim=[0,1,2]))
+        '''
+        for k, v in d.items():
+            if isinstance(v, torch.Tensor) and len(v.shape) >= 3:
+                os.makedirs(f'{output_path}\\{k}', exist_ok=True)
+                torchvision.utils.save_image(v, f'{output_path}\\{k}\\{i}.png')
         if i >= 200000:
             break

View File

@@ -0,0 +1,11 @@
+from torchvision.models import vgg16
+
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+@register_model
+def register_torch_vgg16(opt_net, opt):
+    """ return a torchvision VGG-16 object
+    """
+    return vgg16(**opt_get(opt_net, ['kwargs'], {}))
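The register_model decorator exposes this builder to the trainer's network factory; any kwargs in the network options are passed straight through to torchvision's vgg16 constructor. A hedged usage sketch (the options dict here is hypothetical):

opt_net = {'kwargs': {'num_classes': 2}}        # hypothetical network options
model = register_torch_vgg16(opt_net, opt={})   # -> torchvision VGG-16 with a 2-way classifier head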

View File

@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+class WideKernelVgg(nn.Module):
+    def __init__(self, nf=64, num_classes=2):
+        super().__init__()
+        self.net = nn.Sequential(
+            # [64, 128, 128]
+            nn.Conv2d(6, nf, 7, 1, 3, bias=True),
+            nn.BatchNorm2d(nf, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf, nf, 7, 1, 3, bias=False),
+            nn.BatchNorm2d(nf, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf, nf, 5, 2, 2, bias=False),
+            nn.BatchNorm2d(nf, affine=True),
+            nn.ReLU(),
+            # [64, 64, 64]
+            nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 2, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 2, affine=True),
+            nn.ReLU(),
+            # [128, 32, 32]
+            nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 4, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 4, affine=True),
+            nn.ReLU(),
+            # [256, 16, 16]
+            nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            # [512, 8, 8]
+            nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2),
+            nn.Flatten(),
+            nn.Linear(nf * 8 * 4 * 2, 100),
+            nn.ReLU(),
+            nn.Linear(100, num_classes)
+        )
+        # These normalization constants should be derived experimentally.
+        self.log_fft_mean = torch.tensor([-3.5184, -4.071]).view(1,1,1,2)
+        self.log_fft_std = torch.tensor([3.1660, 3.8042]).view(1,1,1,2)
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x_c = x.view(c*b, h, w)
+        x_c = torch.view_as_real(torch.fft.rfft(x_c))
+        # Log-normalize spectrogram
+        x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16)
+        x_c = torch.log(x_c)
+        x_c = (x_c - self.log_fft_mean.to(x.device)) / self.log_fft_std.to(x.device)
+        # Return to expected input shape (b,c,h,w)
+        x_c = x_c.permute(0, 3, 1, 2).reshape(b, c * 2, h, w // 2 + 1)
+        return self.net(x_c)
+
+
+@register_model
+def register_wide_kernel_vgg(opt_net, opt):
+    """ return a WideKernelVgg object
+    """
+    return WideKernelVgg(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    vgg = WideKernelVgg()
+    vgg(torch.randn(1,3,256,256))
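Tracing the forward pass once explains why the first convolution takes 6 input channels: rfft over each 256-sample row yields w // 2 + 1 = 129 complex bins, view_as_real splits them into real and imaginary planes, and the final permute/reshape folds those two planes into the channel dimension. A shape-only sketch for a single 256x256 RGB input:

import torch

x = torch.randn(1, 3, 256, 256)                   # b, c, h, w
b, c, h, w = x.shape
x_c = x.view(c * b, h, w)                          # (3, 256, 256)
x_c = torch.view_as_real(torch.fft.rfft(x_c))      # (3, 256, 129, 2): real/imag per frequency bin
x_c = torch.log((x_c.abs() ** 2).clip(min=1e-8, max=1e16))
x_c = x_c.permute(0, 3, 1, 2).reshape(b, c * 2, h, w // 2 + 1)
print(x_c.shape)                                    # torch.Size([1, 6, 256, 129])

The commented-out block in image_folder_dataset.py's __main__ above computes the mean and std of this same log-power spectrum, which is presumably how the hard-coded log_fft_mean/log_fft_std constants were obtained.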

View File

@@ -13,14 +13,14 @@ import torch
 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 4
-    opt['compression_level'] = 90  # JPEG compression quality rating.
+    opt['n_thread'] = 8
+    opt['compression_level'] = 95  # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If reading raw images during training, use 0 for faster IO speed.
     opt['dest'] = 'file'
-    opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5'
+    opt['input_folder'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
+    opt['save_folder'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked_pt2'
     opt['crop_sz'] = [256, 512]  # the size of each sub-image
     opt['step'] = [256, 512]  # step of the sliding crop window
     opt['exclusions'] = [[],[]]  # image names matching these terms won't be included in the processing.
@@ -28,6 +28,8 @@ def main():
     opt['resize_final_img'] = [1, .5]
     opt['only_resize'] = False
     opt['vertical_split'] = False
+    opt['use_masking'] = True
+    opt['mask_path'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new_masks'
     opt['input_image_max_size_before_being_halved'] = 5500  # As described, images larger than this dimensional size will be halved before anything else is done.
     # This helps prevent images from cameras with "false-megapixels" from polluting the dataset.
     # False-megapixel=lots of noise at ultra-high res.
@@ -150,7 +152,7 @@ class TiledDataset(data.Dataset):
         # Wrap in a tuple to align with split mode.
         return (self.get(index, False, False), None)

-    def get_for_scale(self, img, crop_sz, step, resize_factor, ref_resize_factor):
+    def get_for_scale(self, img, mask, crop_sz, step, resize_factor, ref_resize_factor):
         thres_sz = self.opt['thres_sz']
         h, w, c = img.shape
@@ -172,6 +174,15 @@ class TiledDataset(data.Dataset):
         for y in w_space:
             index += 1
             crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
+            if mask is not None:
+                def mask_map(inp):
+                    mask_factor = 256 / (crop_sz * ref_resize_factor)
+                    return int(inp * mask_factor)
+                crop_mask = mask[mask_map(x):mask_map(x+crop_sz),
+                                 mask_map(y):mask_map(y+crop_sz),
+                                 :]
+                if crop_mask.mean() < 255 / 2:  # If at least 50% of the image isn't made up of the type of pixels we want to process, ignore this tile.
+                    continue
             # Center point needs to be resized by ref_resize_factor - since it is relative to the reference image.
             center_point = (int((x + crop_sz // 2) // ref_resize_factor), int((y + crop_sz // 2) // ref_resize_factor))
             crop_img = np.ascontiguousarray(crop_img)
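The mask handling skips any tile whose mask region is mostly "off": crop coordinates in the source image are rescaled by 256 / (crop_sz * ref_resize_factor) before indexing into the mask (the 256 constant presumably matching the resolution the masks were saved at), and a tile is kept only when the mean mask value reaches 255/2, i.e. at least roughly half masked-in pixels assuming 0/255 masks. A small arithmetic sketch with hypothetical numbers:

crop_sz, ref_resize_factor = 512, 2.0
mask_factor = 256 / (crop_sz * ref_resize_factor)       # 0.25
x, y = 512, 1024                                         # tile's top-left corner in image coords
print(int(x * mask_factor), int(y * mask_factor))        # 128 256 -> top-left corner in mask coords
# crop_mask = mask[128:256, 256:384, :]; keep the tile only if crop_mask.mean() >= 127.5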
@@ -185,10 +196,11 @@ class TiledDataset(data.Dataset):
     def get(self, index, split_mode, left_img):
         path = self.images[index]
         img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
         if img is None or len(img.shape) == 2:
             return None
+        mask = cv2.imread(os.path.join(self.opt['mask_path'], os.path.basename(path) + ".png"), cv2.IMREAD_UNCHANGED) if self.opt['use_masking'] else None

         h, w, c = img.shape
         if max(h,w) > self.opt['input_image_max_size_before_being_halved']:
@@ -248,7 +260,7 @@ class TiledDataset(data.Dataset):
                     break
             if excluded:
                 continue
-            results.extend(self.get_for_scale(img, crop_sz, step, resize_factor, ref_resize_factor))
+            results.extend(self.get_for_scale(img, mask, crop_sz, step, resize_factor, ref_resize_factor))
         return results, path

     def __len__(self):

View File

@@ -299,7 +299,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_xstart.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

View File

@@ -101,8 +101,8 @@ class ConfigurableStep(Module):
                                    betas=(opt_config['beta1'], opt_config['beta2']))
         elif self.step_opt['optimizer'] == 'adamw':
             opt = torch.optim.AdamW(list(optim_params.values()),
-                                    weight_decay=opt_config['weight_decay'],
-                                    betas=(opt_config['beta1'], opt_config['beta2']))
+                                    weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
+                                    betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
         elif self.step_opt['optimizer'] == 'lars':
             from trainer.optimizers.larc import LARC
             from trainer.optimizers.sgd import SGDNoBiasMomentum
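opt_get walks the given key path and falls back to the supplied default when a key is absent, so AdamW configs no longer have to spell out weight_decay or the betas; the fallbacks (1e-2 and 0.9/0.999) match PyTorch's own AdamW defaults. A small sketch of the lookup behaviour:

from utils.util import opt_get

opt_config = {'beta1': 0.5}                          # hypothetical step config with most fields omitted
print(opt_get(opt_config, ['weight_decay'], 1e-2))   # 0.01  (fallback default)
print(opt_get(opt_config, ['beta1'], 0.9))           # 0.5   (value from the config wins)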

View File

@@ -4,25 +4,28 @@ import time
 import argparse
 import os

+from torchvision.transforms import CenterCrop
+
+from trainer.ExtensibleTrainer import ExtensibleTrainer
 from utils import options as option
 import utils.util as util
 from data import create_dataset, create_dataloader
-from models import create_model
 from tqdm import tqdm
 import torch
 import torchvision

 if __name__ == "__main__":
-    bin_path = "f:\\binned"
-    good_path = "f:\\good"
+    bin_path = "f:\\tmp\\binned"
+    good_path = "f:\\tmp\\good"
     os.makedirs(bin_path, exist_ok=True)
     os.makedirs(good_path, exist_ok=True)

     torch.backends.cudnn.benchmark = True

     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/discriminator_filter.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_quality_detectors/train_resnet_jpeg.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     opt['dist'] = False
@ -43,7 +46,7 @@ if __name__ == "__main__":
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
test_loaders.append(test_loader) test_loaders.append(test_loader)
model = create_model(opt) model = ExtensibleTrainer(opt)
fea_loss = 0 fea_loss = 0
for test_loader in test_loaders: for test_loader in test_loaders:
test_set_name = test_loader.dataset.opt['name'] test_set_name = test_loader.dataset.opt['name']
@ -55,19 +58,17 @@ if __name__ == "__main__":
tq = tqdm(test_loader) tq = tqdm(test_loader)
removed = 0 removed = 0
means = [] means = []
dataset_mean = -7.133 for k, data in enumerate(tq):
for data in tq: model.feed_data(data, k)
model.feed_data(data, need_GT=True)
model.test() model.test()
results = model.eval_state['discriminator_out'][0] results = torch.argmax(torch.nn.functional.softmax(model.eval_state['logits'][0], dim=-1), dim=1)
means.append(torch.mean(results).item())
print(sum(means)/len(means), torch.mean(results), torch.max(results), torch.min(results))
for i in range(results.shape[0]): for i in range(results.shape[0]):
#if results[i] < .8: if results[i] == 0:
# os.remove(data['GT_path'][i]) imname = osp.basename(data['HQ_path'][i])
# removed += 1 # For VERIFICATION:
imname = osp.basename(data['GT_path'][i]) #torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
if results[i]-dataset_mean > 1: # 4 REALZ:
torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname)) os.remove(data['HQ_path'][i])
removed += 1
print("Removed %i/%i images" % (removed, len(test_set))) print("Removed %i/%i images" % (removed, len(test_set)))