Various mods to support better JPEG image filtering

James Betker 2021-06-25 13:16:15 -06:00
parent e7890dc0ba
commit a57ed8e960
8 changed files with 186 additions and 41 deletions

View File

@@ -9,6 +9,8 @@ from io import BytesIO
# Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data)
+from utils.util import opt_get

'''
if __name__ == '__main__':
    import numpy as np
@@ -23,13 +25,15 @@ if __name__ == '__main__':
class ImageCorruptor:
    def __init__(self, opt):
        self.opt = opt
+        self.reset_random()
        self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
        self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
        self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
+        self.cosine_bias = opt_get(opt, ['cosine_bias'], True)
        if self.num_corrupts == 0:
            return
-        self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
-        self.reset_random()
+        else:
+            self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []

    def reset_random(self):
        if 'random_seed' in self.opt.keys():
@@ -41,7 +45,10 @@ class ImageCorruptor:
    # Return is on [0,1] with a bias towards 0.
    def get_rand(self):
        r = self.rand.random()
-        return 1 - cos(r * pi / 2)
+        if self.cosine_bias:
+            return 1 - cos(r * pi / 2)
+        else:
+            return r

    def corrupt_images(self, imgs, return_entropy=False):
        if self.num_corrupts == 0 and not self.fixed_corruptions:
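Note: with cosine_bias enabled, the uniform draw r is warped toward 0 while staying in [0,1]. A quick standalone sketch of the skew:

from math import cos, pi

# 1 - cos(r*pi/2) is monotonic on [0,1] and compresses mass toward 0:
#   r=0.25 -> ~0.076, r=0.50 -> ~0.293, r=0.75 -> ~0.617, r=1.00 -> 1.0
for r in (0.25, 0.5, 0.75, 1.0):
    print(round(r, 2), round(1 - cos(r * pi / 2), 3))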
@@ -82,10 +89,10 @@ class ImageCorruptor:
                img = (img // quant_div) * quant_div
                img = img / 255
            elif 'gaussian_blur' in aug:
-                img = cv2.GaussianBlur(img, (0,0), rand_val*1.5)
+                img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
            elif 'motion_blur' in aug:
                # Motion blur
-                intensity = self.blur_scale * rand_val * 3 + 1
+                intensity = self.blur_scale*rand_val * 3 + 1
                angle = random.randint(0,360)
                k = np.zeros((intensity, intensity), dtype=np.float32)
                k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
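The hunk cuts off before the kernel is applied. For orientation, a minimal standalone sketch of the usual OpenCV motion-blur recipe this code appears to follow — the rotation and convolution steps below are an assumption, as the diff ends before them:

import cv2
import numpy as np

def motion_blur(img, intensity, angle):
    # A single row of ones through the kernel center: a horizontal line blur.
    k = np.zeros((intensity, intensity), dtype=np.float32)
    k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
    # Rotate the line to the requested angle, normalize, and convolve.
    rot = cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0)
    k = cv2.warpAffine(k, rot, (intensity, intensity))
    k = k / np.sum(k)
    return cv2.filter2D(img, -1, k)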

View File

@@ -1,3 +1,4 @@
+import functools
import glob
import itertools
import random
@@ -11,7 +12,7 @@ import os
import torchvision
from torch.utils.data import DataLoader
-from torchvision.transforms import Normalize
+from torchvision.transforms import Normalize, CenterCrop
from tqdm import tqdm

from data import util
@@ -21,10 +22,19 @@ from data.image_label_parser import VsNetImageLabeler
from utils.util import opt_get

+def ndarray_center_crop(crop, img):
+    y, x, c = img.shape
+    startx = x // 2 - crop // 2
+    starty = y // 2 - crop // 2
+    return img[starty:starty + crop, startx:startx + crop, :]

class ImageFolderDataset:
    def __init__(self, opt):
        self.opt = opt
        self.corruptor = ImageCorruptor(opt)
+        if 'center_crop_hq_sz' in opt.keys():
+            self.center_crop = functools.partial(ndarray_center_crop, opt['center_crop_hq_sz'])
        self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
        self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
        self.scale = opt['scale']
@@ -132,6 +142,8 @@ class ImageFolderDataset:
    def __getitem__(self, item):
        hq = util.read_img(None, self.image_paths[item], rgb=True)
+        if hasattr(self, 'center_crop'):
+            hq = self.center_crop(hq)

        if not self.disable_flip and random.random() < .5:
            hq = hq[:, ::-1, :]
@@ -223,25 +235,41 @@ class ImageFolderDataset:
if __name__ == '__main__':
    opt = {
        'name': 'amalgam',
-        'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
+        'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked'],
        'weights': [1],
        'target_size': 256,
        'force_multiple': 1,
        'scale': 2,
        'corrupt_before_downsize': True,
        'fetch_alt_image': False,
        'disable_flip': True,
-        'fixed_corruptions': [ 'jpeg-broad', 'gaussian_blur' ],
+        'fixed_corruptions': [ 'jpeg-medium' ],
        'num_corrupts_per_image': 0,
        'corruption_blur_scale': 0
    }

-    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0)
+    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=4, batch_size=64)
    import os
-    output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
+    output_path = 'F:\\tmp'
    os.makedirs(output_path, exist_ok=True)
+    res = []
    for i, d in tqdm(enumerate(ds)):
        lq = d['lq']
        #torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
+        '''
+        x = d['hq']
+        b,c,h,w = x.shape
+        x_c = x.view(c*b, h, w)
+        x_c = torch.view_as_real(torch.fft.rfft(x_c))
+        # Log-normalize spectrogram
+        x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16)
+        x_c = torch.log(x_c)
+        res.append(x_c)
+        if i % 100 == 99:
+            stacked = torch.cat(res, dim=0)
+            print(stacked.mean(dim=[0,1,2]), stacked.std(dim=[0,1,2]))
+        '''
+        for k, v in d.items():
+            if isinstance(v, torch.Tensor) and len(v.shape) >= 3:
+                os.makedirs(f'{output_path}\\{k}', exist_ok=True)
+                torchvision.utils.save_image(v, f'{output_path}\\{k}\\{i}.png')
-        if i >= 200000:
-            break
+        break
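The commented-out block above is presumably how the log_fft_mean/log_fft_std constants in WideKernelVgg (added below) were measured. A cleaned-up standalone sketch of that statistic gathering; the function name is illustrative, and `loader` is assumed to yield dicts with an 'hq' tensor of shape (b, c, h, w):

import torch

def gather_log_fft_stats(loader, max_batches=100):
    # Accumulate per-component mean/std of the log power spectrum over a dataset.
    res = []
    for i, d in enumerate(loader):
        x = d['hq']
        b, c, h, w = x.shape
        x_c = torch.view_as_real(torch.fft.rfft(x.view(c * b, h, w)))
        x_c = torch.log((x_c.abs() ** 2).clip(min=1e-8, max=1e16))
        res.append(x_c)
        if i >= max_batches:
            break
    stacked = torch.cat(res, dim=0)
    # Two values: one for the real component, one for the imaginary component.
    return stacked.mean(dim=[0, 1, 2]), stacked.std(dim=[0, 1, 2])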

View File

@@ -0,0 +1,11 @@
from torchvision.models import vgg16

from trainer.networks import register_model
from utils.util import opt_get


@register_model
def register_torch_vgg16(opt_net, opt):
    """ Returns a torchvision VGG-16 object. """
    return vgg16(**opt_get(opt_net, ['kwargs'], {}))
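Models registered this way are normally looked up by the trainer from the options file; absent that machinery, the registration function can also be exercised directly. A minimal sketch — the 'kwargs' value is illustrative:

# Hypothetical direct invocation; 'kwargs' mirrors the opt_net structure assumed above.
opt_net = {'kwargs': {'num_classes': 2}}
model = register_torch_vgg16(opt_net, opt={})  # returns torchvision's VGG-16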

View File

@@ -0,0 +1,86 @@
import torch
import torch.nn as nn

from trainer.networks import register_model
from utils.util import opt_get


class WideKernelVgg(nn.Module):
    def __init__(self, nf=64, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            # [64, 128, 128]
            nn.Conv2d(6, nf, 7, 1, 3, bias=True),
            nn.BatchNorm2d(nf, affine=True),
            nn.ReLU(),
            nn.Conv2d(nf, nf, 7, 1, 3, bias=False),
            nn.BatchNorm2d(nf, affine=True),
            nn.ReLU(),
            nn.Conv2d(nf, nf, 5, 2, 2, bias=False),
            nn.BatchNorm2d(nf, affine=True),
            nn.ReLU(),
            # [64, 64, 64]
            nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False),
            nn.BatchNorm2d(nf * 2, affine=True),
            nn.ReLU(),
            nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 2, affine=True),
            nn.ReLU(),
            # [128, 32, 32]
            nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False),
            nn.BatchNorm2d(nf * 4, affine=True),
            nn.ReLU(),
            nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 4, affine=True),
            nn.ReLU(),
            # [256, 16, 16]
            nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False),
            nn.BatchNorm2d(nf * 8, affine=True),
            nn.ReLU(),
            nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 8, affine=True),
            nn.ReLU(),
            # [512, 8, 8]
            nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False),
            nn.BatchNorm2d(nf * 8, affine=True),
            nn.ReLU(),
            nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 8, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(nf * 8 * 4 * 2, 100),
            nn.ReLU(),
            nn.Linear(100, num_classes)
        )

        # These normalization constants should be derived experimentally.
        self.log_fft_mean = torch.tensor([-3.5184, -4.071]).view(1,1,1,2)
        self.log_fft_std = torch.tensor([3.1660, 3.8042]).view(1,1,1,2)
    def forward(self, x):
        b,c,h,w = x.shape
        x_c = x.view(c*b, h, w)
        x_c = torch.view_as_real(torch.fft.rfft(x_c))

        # Log-normalize spectrogram
        x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16)
        x_c = torch.log(x_c)
        x_c = (x_c - self.log_fft_mean.to(x.device)) / self.log_fft_std.to(x.device)

        # Return to expected input shape (b,c,h,w)
        x_c = x_c.permute(0, 3, 1, 2).reshape(b, c * 2, h, w // 2 + 1)
        return self.net(x_c)
@register_model
def register_wide_kernel_vgg(opt_net, opt):
    """ Returns a WideKernelVgg object. """
    return WideKernelVgg(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
    vgg = WideKernelVgg()
    vgg(torch.randn(1,3,256,256))
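The FFT front-end is why the first convolution takes 6 input channels for 3-channel images: rfft over the width yields w//2 + 1 complex bins, and view_as_real splits each into two real channels. A shape sketch (log-normalization omitted):

import torch

x = torch.randn(1, 3, 256, 256)
x_c = torch.view_as_real(torch.fft.rfft(x.view(3, 256, 256)))  # -> (3, 256, 129, 2)
x_c = x_c.permute(0, 3, 1, 2).reshape(1, 6, 256, 129)          # -> (b, c*2, h, w//2 + 1)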

View File

@@ -13,14 +13,14 @@ import torch
def main():
    split_img = False
    opt = {}
-    opt['n_thread'] = 4
-    opt['compression_level'] = 90  # JPEG compression quality rating.
+    opt['n_thread'] = 8
+    opt['compression_level'] = 95  # JPEG compression quality rating.
    # CV_IMWRITE_PNG_COMPRESSION ranges from 0 to 9. A higher value means a smaller size but longer
    # compression time. If reading raw images during training, use 0 for faster IO speed.
    opt['dest'] = 'file'
-    opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5'
+    opt['input_folder'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
+    opt['save_folder'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked_pt2'
    opt['crop_sz'] = [256, 512]  # the size of each sub-image
    opt['step'] = [256, 512]  # step of the sliding crop window
    opt['exclusions'] = [[],[]]  # image names matching these terms won't be included in the processing.
@@ -28,6 +28,8 @@ def main():
    opt['resize_final_img'] = [1, .5]
    opt['only_resize'] = False
    opt['vertical_split'] = False
+    opt['use_masking'] = True
+    opt['mask_path'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new_masks'
    opt['input_image_max_size_before_being_halved'] = 5500  # As described: images larger than this dimension will be halved before anything else is done.
    # This helps prevent images from cameras with "false megapixels" from polluting the dataset.
    # (False megapixels = lots of noise at ultra-high resolution.)
@@ -150,7 +152,7 @@ class TiledDataset(data.Dataset):
        # Wrap in a tuple to align with split mode.
        return (self.get(index, False, False), None)

-    def get_for_scale(self, img, crop_sz, step, resize_factor, ref_resize_factor):
+    def get_for_scale(self, img, mask, crop_sz, step, resize_factor, ref_resize_factor):
        thres_sz = self.opt['thres_sz']
        h, w, c = img.shape
@@ -172,6 +174,15 @@ class TiledDataset(data.Dataset):
            for y in w_space:
                index += 1
                crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
+                if mask is not None:
+                    def mask_map(inp):
+                        mask_factor = 256 / (crop_sz * ref_resize_factor)
+                        return int(inp * mask_factor)
+                    crop_mask = mask[mask_map(x):mask_map(x+crop_sz),
+                                     mask_map(y):mask_map(y+crop_sz),
+                                     :]
+                    if crop_mask.mean() < 255 / 2:  # Skip this tile if less than 50% of it consists of the type of pixels we want to process.
+                        continue

                # The center point needs to be resized by ref_resize_factor, since it is relative to the reference image.
                center_point = (int((x + crop_sz // 2) // ref_resize_factor), int((y + crop_sz // 2) // ref_resize_factor))
                crop_img = np.ascontiguousarray(crop_img)
@@ -185,10 +196,11 @@ class TiledDataset(data.Dataset):
    def get(self, index, split_mode, left_img):
        path = self.images[index]
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None or len(img.shape) == 2:
            return None
+        mask = cv2.imread(os.path.join(self.opt['mask_path'], os.path.basename(path) + ".png"), cv2.IMREAD_UNCHANGED) if self.opt['use_masking'] else None

        h, w, c = img.shape
        if max(h,w) > self.opt['input_image_max_size_before_being_halved']:
@@ -248,7 +260,7 @@ class TiledDataset(data.Dataset):
                    break
            if excluded:
                continue
-            results.extend(self.get_for_scale(img, crop_sz, step, resize_factor, ref_resize_factor))
+            results.extend(self.get_for_scale(img, mask, crop_sz, step, resize_factor, ref_resize_factor))
        return results, path

    def __len__(self):
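For readers outside the diff, a self-contained sketch of the masking rule introduced above — the function name is illustrative, masks are assumed single-channel, and (per mask_map) stored at a scale of 256 pixels per crop_sz * ref_resize_factor source pixels:

import numpy as np

def tile_passes_mask(mask, x, y, crop_sz, ref_resize_factor):
    # Returns True if at least half of the tile's mask pixels are "keep" (255).
    mask_factor = 256 / (crop_sz * ref_resize_factor)
    to_mask = lambda v: int(v * mask_factor)
    crop_mask = mask[to_mask(x):to_mask(x + crop_sz), to_mask(y):to_mask(y + crop_sz)]
    return crop_mask.mean() >= 255 / 2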

View File

@@ -299,7 +299,7 @@ class Trainer:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_xstart.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

View File

@@ -101,8 +101,8 @@ class ConfigurableStep(Module):
                                    betas=(opt_config['beta1'], opt_config['beta2']))
        elif self.step_opt['optimizer'] == 'adamw':
            opt = torch.optim.AdamW(list(optim_params.values()),
-                                    weight_decay=opt_config['weight_decay'],
-                                    betas=(opt_config['beta1'], opt_config['beta2']))
+                                    weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
+                                    betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
        elif self.step_opt['optimizer'] == 'lars':
            from trainer.optimizers.larc import LARC
            from trainer.optimizers.sgd import SGDNoBiasMomentum
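This change makes weight_decay and the betas optional, falling back to PyTorch's usual AdamW defaults. opt_get itself is not shown in this commit; based on its call sites here, a presumed minimal implementation would be:

def opt_get(opt, keys, default=None):
    # Presumed behavior of utils.util.opt_get, inferred from usage in this
    # commit: walk nested dict keys, returning `default` when any is absent.
    if opt is None:
        return default
    for k in keys:
        if not isinstance(opt, dict) or k not in opt:
            return default
        opt = opt[k]
    return opt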

View File

@@ -4,25 +4,28 @@ import time
import argparse
import os

+from torchvision.transforms import CenterCrop
+from trainer.ExtensibleTrainer import ExtensibleTrainer
from utils import options as option
import utils.util as util
from data import create_dataset, create_dataloader
from models import create_model
from tqdm import tqdm
import torch
+import torchvision

if __name__ == "__main__":
-    bin_path = "f:\\binned"
-    good_path = "f:\\good"
+    bin_path = "f:\\tmp\\binned"
+    good_path = "f:\\tmp\\good"
    os.makedirs(bin_path, exist_ok=True)
    os.makedirs(good_path, exist_ok=True)
    torch.backends.cudnn.benchmark = True

    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/discriminator_filter.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_quality_detectors/train_resnet_jpeg.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    opt['dist'] = False
opt['dist'] = False
@@ -43,7 +46,7 @@ if __name__ == "__main__":
        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

-    model = create_model(opt)
+    model = ExtensibleTrainer(opt)
    fea_loss = 0
    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
@@ -55,19 +58,17 @@ if __name__ == "__main__":
        tq = tqdm(test_loader)
        removed = 0
        means = []
-        dataset_mean = -7.133
-        for data in tq:
-            model.feed_data(data, need_GT=True)
+        for k, data in enumerate(tq):
+            model.feed_data(data, k)
            model.test()
-            results = model.eval_state['discriminator_out'][0]
-            means.append(torch.mean(results).item())
-            print(sum(means)/len(means), torch.mean(results), torch.max(results), torch.min(results))
+            results = torch.argmax(torch.nn.functional.softmax(model.eval_state['logits'][0], dim=-1), dim=1)
            for i in range(results.shape[0]):
-                #if results[i] < .8:
-                #    os.remove(data['GT_path'][i])
-                #    removed += 1
-                imname = osp.basename(data['GT_path'][i])
-                if results[i]-dataset_mean > 1:
-                    torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
+                if results[i] == 0:
+                    imname = osp.basename(data['HQ_path'][i])
+                    # For VERIFICATION:
+                    #torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
+                    # 4 REALZ:
+                    os.remove(data['HQ_path'][i])
+                    removed += 1
        print("Removed %i/%i images" % (removed, len(test_set)))