forked from mrq/DL-Art-School
Various mods to support better jpeg image filtering
This commit is contained in:
parent
e7890dc0ba
commit
a57ed8e960
|
@ -9,6 +9,8 @@ from io import BytesIO
|
|||
|
||||
|
||||
# Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data)
|
||||
from utils.util import opt_get
|
||||
|
||||
'''
|
||||
if __name__ == '__main__':
|
||||
import numpy as np
|
||||
|
@ -23,13 +25,15 @@ if __name__ == '__main__':
|
|||
class ImageCorruptor:
|
||||
def __init__(self, opt):
|
||||
self.opt = opt
|
||||
self.reset_random()
|
||||
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
|
||||
self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
|
||||
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
|
||||
self.cosine_bias = opt_get(opt, ['cosine_bias'], True)
|
||||
if self.num_corrupts == 0:
|
||||
return
|
||||
self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
|
||||
self.reset_random()
|
||||
else:
|
||||
self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
|
||||
|
||||
def reset_random(self):
|
||||
if 'random_seed' in self.opt.keys():
|
||||
|
@ -41,7 +45,10 @@ class ImageCorruptor:
|
|||
# Return is on [0,1] with a bias towards 0.
|
||||
def get_rand(self):
|
||||
r = self.rand.random()
|
||||
return 1 - cos(r * pi / 2)
|
||||
if self.cosine_bias:
|
||||
return 1 - cos(r * pi / 2)
|
||||
else:
|
||||
return r
|
||||
|
||||
def corrupt_images(self, imgs, return_entropy=False):
|
||||
if self.num_corrupts == 0 and not self.fixed_corruptions:
|
||||
|
@ -82,10 +89,10 @@ class ImageCorruptor:
|
|||
img = (img // quant_div) * quant_div
|
||||
img = img / 255
|
||||
elif 'gaussian_blur' in aug:
|
||||
img = cv2.GaussianBlur(img, (0,0), rand_val*1.5)
|
||||
img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
|
||||
elif 'motion_blur' in aug:
|
||||
# Motion blur
|
||||
intensity = self.blur_scale * rand_val * 3 + 1
|
||||
intensity = self.blur_scale*rand_val * 3 + 1
|
||||
angle = random.randint(0,360)
|
||||
k = np.zeros((intensity, intensity), dtype=np.float32)
|
||||
k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import functools
|
||||
import glob
|
||||
import itertools
|
||||
import random
|
||||
|
@ -11,7 +12,7 @@ import os
|
|||
|
||||
import torchvision
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Normalize
|
||||
from torchvision.transforms import Normalize, CenterCrop
|
||||
from tqdm import tqdm
|
||||
|
||||
from data import util
|
||||
|
@ -21,10 +22,19 @@ from data.image_label_parser import VsNetImageLabeler
|
|||
from utils.util import opt_get
|
||||
|
||||
|
||||
def ndarray_center_crop(crop, img):
|
||||
y, x, c = img.shape
|
||||
startx = x // 2 - crop // 2
|
||||
starty = y // 2 - crop // 2
|
||||
return img[starty:starty + crop, startx:startx + crop, :]
|
||||
|
||||
|
||||
class ImageFolderDataset:
|
||||
def __init__(self, opt):
|
||||
self.opt = opt
|
||||
self.corruptor = ImageCorruptor(opt)
|
||||
if 'center_crop_hq_sz' in opt.keys():
|
||||
self.center_crop = functools.partial(ndarray_center_crop, opt['center_crop_hq_sz'])
|
||||
self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
|
||||
self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
|
||||
self.scale = opt['scale']
|
||||
|
@ -132,6 +142,8 @@ class ImageFolderDataset:
|
|||
|
||||
def __getitem__(self, item):
|
||||
hq = util.read_img(None, self.image_paths[item], rgb=True)
|
||||
if hasattr(self, 'center_crop'):
|
||||
hq = self.center_crop(hq)
|
||||
if not self.disable_flip and random.random() < .5:
|
||||
hq = hq[:, ::-1, :]
|
||||
|
||||
|
@ -223,25 +235,41 @@ class ImageFolderDataset:
|
|||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'name': 'amalgam',
|
||||
'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||
'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked'],
|
||||
'weights': [1],
|
||||
'target_size': 256,
|
||||
'force_multiple': 1,
|
||||
'scale': 2,
|
||||
'corrupt_before_downsize': True,
|
||||
'fetch_alt_image': False,
|
||||
'disable_flip': True,
|
||||
'fixed_corruptions': [ 'jpeg-broad', 'gaussian_blur' ],
|
||||
'fixed_corruptions': [ 'jpeg-medium' ],
|
||||
'num_corrupts_per_image': 0,
|
||||
'corruption_blur_scale': 0
|
||||
}
|
||||
|
||||
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0)
|
||||
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=4, batch_size=64)
|
||||
import os
|
||||
output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
|
||||
output_path = 'F:\\tmp'
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
res = []
|
||||
for i, d in tqdm(enumerate(ds)):
|
||||
lq = d['lq']
|
||||
#torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
|
||||
'''
|
||||
x = d['hq']
|
||||
b,c,h,w = x.shape
|
||||
x_c = x.view(c*b, h, w)
|
||||
x_c = torch.view_as_real(torch.fft.rfft(x_c))
|
||||
# Log-normalize spectrogram
|
||||
x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16)
|
||||
x_c = torch.log(x_c)
|
||||
res.append(x_c)
|
||||
if i % 100 == 99:
|
||||
stacked = torch.cat(res, dim=0)
|
||||
print(stacked.mean(dim=[0,1,2]), stacked.std(dim=[0,1,2]))
|
||||
'''
|
||||
|
||||
for k, v in d.items():
|
||||
if isinstance(v, torch.Tensor) and len(v.shape) >= 3:
|
||||
os.makedirs(f'{output_path}\\{k}', exist_ok=True)
|
||||
torchvision.utils.save_image(v, f'{output_path}\\{k}\\{i}.png')
|
||||
if i >= 200000:
|
||||
break
|
||||
break
|
||||
|
|
11
codes/models/classifiers/torch_models.py
Normal file
11
codes/models/classifiers/torch_models.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from torchvision.models import vgg16
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
@register_model
|
||||
def register_torch_vgg16(opt_net, opt):
|
||||
""" return a ResNet 18 object
|
||||
"""
|
||||
return vgg16(**opt_get(opt_net, ['kwargs'], {}))
|
86
codes/models/classifiers/wide_kernel_vgg.py
Normal file
86
codes/models/classifiers/wide_kernel_vgg.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class WideKernelVgg(nn.Module):
|
||||
def __init__(self, nf=64, num_classes=2):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
# [64, 128, 128]
|
||||
nn.Conv2d(6, nf, 7, 1, 3, bias=True),
|
||||
nn.BatchNorm2d(nf, affine=True),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(nf, nf, 7, 1, 3, bias=False),
|
||||
nn.BatchNorm2d(nf, affine=True),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(nf, nf, 5, 2, 2, bias=False),
|
||||
nn.BatchNorm2d(nf, affine=True),
|
||||
nn.ReLU(),
|
||||
# [64, 64, 64]
|
||||
nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False),
|
||||
nn.BatchNorm2d(nf * 2, affine=True),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(nf * 2, affine=True),
|
||||
nn.ReLU(),
|
||||
# [128, 32, 32]
|
||||
nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False),
|
||||
nn.BatchNorm2d(nf * 4, affine=True),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(nf * 4, affine=True),
|
||||
nn.ReLU(),
|
||||
# [256, 16, 16]
|
||||
nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False),
|
||||
nn.BatchNorm2d(nf * 8, affine=True),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(nf * 8, affine=True),
|
||||
nn.ReLU(),
|
||||
# [512, 8, 8]
|
||||
nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False),
|
||||
nn.BatchNorm2d(nf * 8, affine=True),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(nf * 8, affine=True),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2),
|
||||
nn.Flatten(),
|
||||
nn.Linear(nf * 8 * 4 * 2, 100),
|
||||
nn.ReLU(),
|
||||
nn.Linear(100, num_classes)
|
||||
)
|
||||
|
||||
# These normalization constants should be derived experimentally.
|
||||
self.log_fft_mean = torch.tensor([-3.5184, -4.071]).view(1,1,1,2)
|
||||
self.log_fft_std = torch.tensor([3.1660, 3.8042]).view(1,1,1,2)
|
||||
|
||||
def forward(self, x):
|
||||
b,c,h,w = x.shape
|
||||
x_c = x.view(c*b, h, w)
|
||||
x_c = torch.view_as_real(torch.fft.rfft(x_c))
|
||||
|
||||
# Log-normalize spectrogram
|
||||
x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16)
|
||||
x_c = torch.log(x_c)
|
||||
x_c = (x_c - self.log_fft_mean.to(x.device)) / self.log_fft_std.to(x.device)
|
||||
|
||||
# Return to expected input shape (b,c,h,w)
|
||||
x_c = x_c.permute(0, 3, 1, 2).reshape(b, c * 2, h, w // 2 + 1)
|
||||
|
||||
return self.net(x_c)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_wide_kernel_vgg(opt_net, opt):
|
||||
""" return a ResNet 18 object
|
||||
"""
|
||||
return WideKernelVgg(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
vgg = WideKernelVgg()
|
||||
vgg(torch.randn(1,3,256,256))
|
|
@ -13,14 +13,14 @@ import torch
|
|||
def main():
|
||||
split_img = False
|
||||
opt = {}
|
||||
opt['n_thread'] = 4
|
||||
opt['compression_level'] = 90 # JPEG compression quality rating.
|
||||
opt['n_thread'] = 8
|
||||
opt['compression_level'] = 95 # JPEG compression quality rating.
|
||||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||
|
||||
opt['dest'] = 'file'
|
||||
opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
|
||||
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5'
|
||||
opt['input_folder'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
|
||||
opt['save_folder'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked_pt2'
|
||||
opt['crop_sz'] = [256, 512] # the size of each sub-image
|
||||
opt['step'] = [256, 512] # step of the sliding crop window
|
||||
opt['exclusions'] = [[],[]] # image names matching these terms wont be included in the processing.
|
||||
|
@ -28,6 +28,8 @@ def main():
|
|||
opt['resize_final_img'] = [1, .5]
|
||||
opt['only_resize'] = False
|
||||
opt['vertical_split'] = False
|
||||
opt['use_masking'] = True
|
||||
opt['mask_path'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new_masks'
|
||||
opt['input_image_max_size_before_being_halved'] = 5500 # As described, images larger than this dimensional size will be halved before anything else is done.
|
||||
# This helps prevent images from cameras with "false-megapixels" from polluting the dataset.
|
||||
# False-megapixel=lots of noise at ultra-high res.
|
||||
|
@ -150,7 +152,7 @@ class TiledDataset(data.Dataset):
|
|||
# Wrap in a tuple to align with split mode.
|
||||
return (self.get(index, False, False), None)
|
||||
|
||||
def get_for_scale(self, img, crop_sz, step, resize_factor, ref_resize_factor):
|
||||
def get_for_scale(self, img, mask, crop_sz, step, resize_factor, ref_resize_factor):
|
||||
thres_sz = self.opt['thres_sz']
|
||||
h, w, c = img.shape
|
||||
|
||||
|
@ -172,6 +174,15 @@ class TiledDataset(data.Dataset):
|
|||
for y in w_space:
|
||||
index += 1
|
||||
crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
|
||||
if mask is not None:
|
||||
def mask_map(inp):
|
||||
mask_factor = 256 / (crop_sz * ref_resize_factor)
|
||||
return int(inp * mask_factor)
|
||||
crop_mask = mask[mask_map(x):mask_map(x+crop_sz),
|
||||
mask_map(y):mask_map(y+crop_sz),
|
||||
:]
|
||||
if crop_mask.mean() < 255 / 2: # If at least 50% of the image isn't made up of the type of pixels we want to process, ignore this tile.
|
||||
continue
|
||||
# Center point needs to be resized by ref_resize_factor - since it is relative to the reference image.
|
||||
center_point = (int((x + crop_sz // 2) // ref_resize_factor), int((y + crop_sz // 2) // ref_resize_factor))
|
||||
crop_img = np.ascontiguousarray(crop_img)
|
||||
|
@ -185,10 +196,11 @@ class TiledDataset(data.Dataset):
|
|||
def get(self, index, split_mode, left_img):
|
||||
path = self.images[index]
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if img is None or len(img.shape) == 2:
|
||||
return None
|
||||
|
||||
mask = cv2.imread(os.path.join(self.opt['mask_path'], os.path.basename(path) + ".png"), cv2.IMREAD_UNCHANGED) if self.opt['use_masking'] else None
|
||||
|
||||
h, w, c = img.shape
|
||||
|
||||
if max(h,w) > self.opt['input_image_max_size_before_being_halved']:
|
||||
|
@ -248,7 +260,7 @@ class TiledDataset(data.Dataset):
|
|||
break
|
||||
if excluded:
|
||||
continue
|
||||
results.extend(self.get_for_scale(img, crop_sz, step, resize_factor, ref_resize_factor))
|
||||
results.extend(self.get_for_scale(img, mask, crop_sz, step, resize_factor, ref_resize_factor))
|
||||
return results, path
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -299,7 +299,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_xstart.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -101,8 +101,8 @@ class ConfigurableStep(Module):
|
|||
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||
elif self.step_opt['optimizer'] == 'adamw':
|
||||
opt = torch.optim.AdamW(list(optim_params.values()),
|
||||
weight_decay=opt_config['weight_decay'],
|
||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
|
||||
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
|
||||
elif self.step_opt['optimizer'] == 'lars':
|
||||
from trainer.optimizers.larc import LARC
|
||||
from trainer.optimizers.sgd import SGDNoBiasMomentum
|
||||
|
|
|
@ -4,25 +4,28 @@ import time
|
|||
import argparse
|
||||
|
||||
import os
|
||||
|
||||
from torchvision.transforms import CenterCrop
|
||||
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from utils import options as option
|
||||
import utils.util as util
|
||||
from data import create_dataset, create_dataloader
|
||||
from models import create_model
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bin_path = "f:\\binned"
|
||||
good_path = "f:\\good"
|
||||
bin_path = "f:\\tmp\\binned"
|
||||
good_path = "f:\\tmp\\good"
|
||||
os.makedirs(bin_path, exist_ok=True)
|
||||
os.makedirs(good_path, exist_ok=True)
|
||||
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/discriminator_filter.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_quality_detectors/train_resnet_jpeg.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
opt['dist'] = False
|
||||
|
@ -43,7 +46,7 @@ if __name__ == "__main__":
|
|||
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
||||
test_loaders.append(test_loader)
|
||||
|
||||
model = create_model(opt)
|
||||
model = ExtensibleTrainer(opt)
|
||||
fea_loss = 0
|
||||
for test_loader in test_loaders:
|
||||
test_set_name = test_loader.dataset.opt['name']
|
||||
|
@ -55,19 +58,17 @@ if __name__ == "__main__":
|
|||
tq = tqdm(test_loader)
|
||||
removed = 0
|
||||
means = []
|
||||
dataset_mean = -7.133
|
||||
for data in tq:
|
||||
model.feed_data(data, need_GT=True)
|
||||
for k, data in enumerate(tq):
|
||||
model.feed_data(data, k)
|
||||
model.test()
|
||||
results = model.eval_state['discriminator_out'][0]
|
||||
means.append(torch.mean(results).item())
|
||||
print(sum(means)/len(means), torch.mean(results), torch.max(results), torch.min(results))
|
||||
results = torch.argmax(torch.nn.functional.softmax(model.eval_state['logits'][0], dim=-1), dim=1)
|
||||
for i in range(results.shape[0]):
|
||||
#if results[i] < .8:
|
||||
# os.remove(data['GT_path'][i])
|
||||
# removed += 1
|
||||
imname = osp.basename(data['GT_path'][i])
|
||||
if results[i]-dataset_mean > 1:
|
||||
torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
|
||||
if results[i] == 0:
|
||||
imname = osp.basename(data['HQ_path'][i])
|
||||
# For VERIFICATION:
|
||||
#torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
|
||||
# 4 REALZ:
|
||||
os.remove(data['HQ_path'][i])
|
||||
removed += 1
|
||||
|
||||
print("Removed %i/%i images" % (removed, len(test_set)))
|
Loading…
Reference in New Issue
Block a user