forked from mrq/DL-Art-School

Various mods to support better jpeg image filtering

commit a57ed8e960 (parent e7890dc0ba)
@@ -9,6 +9,8 @@ from io import BytesIO
 # Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data)
+from utils.util import opt_get
+
 '''
 if __name__ == '__main__':
     import numpy as np
@@ -23,13 +25,15 @@ if __name__ == '__main__':
 class ImageCorruptor:
     def __init__(self, opt):
         self.opt = opt
+        self.reset_random()
         self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
         self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
         self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
+        self.cosine_bias = opt_get(opt, ['cosine_bias'], True)
         if self.num_corrupts == 0:
             return
-        self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
-        self.reset_random()
+        else:
+            self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []

     def reset_random(self):
         if 'random_seed' in self.opt.keys():
@@ -41,7 +45,10 @@ class ImageCorruptor:
     # Return is on [0,1] with a bias towards 0.
     def get_rand(self):
         r = self.rand.random()
-        return 1 - cos(r * pi / 2)
+        if self.cosine_bias:
+            return 1 - cos(r * pi / 2)
+        else:
+            return r

     def corrupt_images(self, imgs, return_entropy=False):
         if self.num_corrupts == 0 and not self.fixed_corruptions:
@@ -82,10 +89,10 @@ class ImageCorruptor:
             img = (img // quant_div) * quant_div
             img = img / 255
         elif 'gaussian_blur' in aug:
-            img = cv2.GaussianBlur(img, (0,0), rand_val*1.5)
+            img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
         elif 'motion_blur' in aug:
             # Motion blur
-            intensity = self.blur_scale * rand_val * 3 + 1
+            intensity = self.blur_scale*rand_val * 3 + 1
             angle = random.randint(0,360)
             k = np.zeros((intensity, intensity), dtype=np.float32)
             k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
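For reference, a quick standalone sketch (not part of the commit) of what the new cosine_bias switch in get_rand() changes: 1 - cos(r * pi / 2) remaps a uniform draw on [0, 1] so that small corruption strengths are sampled more often, while setting cosine_bias to False in the dataset options falls back to a plain uniform draw.

# Illustration only: compare the biased and unbiased draws used by get_rand().
import random
from math import cos, pi

rng = random.Random(0)

def get_rand(cosine_bias=True):
    r = rng.random()
    if cosine_bias:
        return 1 - cos(r * pi / 2)   # biased towards 0
    return r                         # uniform on [0, 1]

biased = [get_rand(True) for _ in range(100000)]
uniform = [get_rand(False) for _ in range(100000)]
print(sum(biased) / len(biased))    # ~0.36 - mass is pushed towards 0
print(sum(uniform) / len(uniform))  # ~0.50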
@@ -1,3 +1,4 @@
+import functools
 import glob
 import itertools
 import random
@@ -11,7 +12,7 @@ import os
 import torchvision
 from torch.utils.data import DataLoader
-from torchvision.transforms import Normalize
+from torchvision.transforms import Normalize, CenterCrop
 from tqdm import tqdm

 from data import util
@@ -21,10 +22,19 @@ from data.image_label_parser import VsNetImageLabeler
 from utils.util import opt_get


+def ndarray_center_crop(crop, img):
+    y, x, c = img.shape
+    startx = x // 2 - crop // 2
+    starty = y // 2 - crop // 2
+    return img[starty:starty + crop, startx:startx + crop, :]
+
+
 class ImageFolderDataset:
     def __init__(self, opt):
         self.opt = opt
         self.corruptor = ImageCorruptor(opt)
+        if 'center_crop_hq_sz' in opt.keys():
+            self.center_crop = functools.partial(ndarray_center_crop, opt['center_crop_hq_sz'])
         self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
         self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
         self.scale = opt['scale']
@@ -132,6 +142,8 @@ class ImageFolderDataset:

     def __getitem__(self, item):
         hq = util.read_img(None, self.image_paths[item], rgb=True)
+        if hasattr(self, 'center_crop'):
+            hq = self.center_crop(hq)
         if not self.disable_flip and random.random() < .5:
             hq = hq[:, ::-1, :]
@@ -223,25 +235,41 @@ class ImageFolderDataset:
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
+        'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked'],
         'weights': [1],
         'target_size': 256,
-        'force_multiple': 1,
         'scale': 2,
         'corrupt_before_downsize': True,
         'fetch_alt_image': False,
         'disable_flip': True,
-        'fixed_corruptions': [ 'jpeg-broad', 'gaussian_blur' ],
+        'fixed_corruptions': [ 'jpeg-medium' ],
         'num_corrupts_per_image': 0,
         'corruption_blur_scale': 0
     }

-    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0)
+    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=4, batch_size=64)
     import os
-    output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
+    output_path = 'F:\\tmp'
     os.makedirs(output_path, exist_ok=True)
+    res = []
     for i, d in tqdm(enumerate(ds)):
-        lq = d['lq']
-        #torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
+        '''
+        x = d['hq']
+        b,c,h,w = x.shape
+        x_c = x.view(c*b, h, w)
+        x_c = torch.view_as_real(torch.fft.rfft(x_c))
+        # Log-normalize spectrogram
+        x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16)
+        x_c = torch.log(x_c)
+        res.append(x_c)
+        if i % 100 == 99:
+            stacked = torch.cat(res, dim=0)
+            print(stacked.mean(dim=[0,1,2]), stacked.std(dim=[0,1,2]))
+        '''
+
+        for k, v in d.items():
+            if isinstance(v, torch.Tensor) and len(v.shape) >= 3:
+                os.makedirs(f'{output_path}\\{k}', exist_ok=True)
+                torchvision.utils.save_image(v, f'{output_path}\\{k}\\{i}.png')
         if i >= 200000:
             break
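A note on the block commented out with ''' above: it computes per-component statistics of the log power spectrum over a dataset, and those appear to be the experimentally derived constants hard-coded as log_fft_mean / log_fft_std in the new codes/models/classifiers/wide_kernel_vgg.py below. A minimal standalone sketch of that relationship, with a random batch standing in for real data:

# Sketch (illustration only): the last dimension of torch.view_as_real(torch.fft.rfft(x))
# holds the (real, imaginary) components, so the printed mean/std are two-element
# vectors - one value per component.
import torch

x = torch.randn(8, 3, 256, 256)       # stand-in for a batch of HQ crops
b, c, h, w = x.shape
x_c = torch.view_as_real(torch.fft.rfft(x.view(c * b, h, w)))
log_power = torch.log((x_c.abs() ** 2).clip(min=1e-8, max=1e16))

mean = log_power.mean(dim=[0, 1, 2])  # shape [2]: one entry per (real, imag) component
std = log_power.std(dim=[0, 1, 2])    # shape [2]
print(mean, std)
# WideKernelVgg hard-codes the values measured on the real dataset, e.g.
# torch.tensor([-3.5184, -4.071]).view(1, 1, 1, 2) for the mean.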
codes/models/classifiers/torch_models.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+from torchvision.models import vgg16
+
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+@register_model
+def register_torch_vgg16(opt_net, opt):
+    """ return a VGG-16 object
+    """
+    return vgg16(**opt_get(opt_net, ['kwargs'], {}))
codes/models/classifiers/wide_kernel_vgg.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+class WideKernelVgg(nn.Module):
+    def __init__(self, nf=64, num_classes=2):
+        super().__init__()
+        self.net = nn.Sequential(
+            # [64, 128, 128]
+            nn.Conv2d(6, nf, 7, 1, 3, bias=True),
+            nn.BatchNorm2d(nf, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf, nf, 7, 1, 3, bias=False),
+            nn.BatchNorm2d(nf, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf, nf, 5, 2, 2, bias=False),
+            nn.BatchNorm2d(nf, affine=True),
+            nn.ReLU(),
+            # [64, 64, 64]
+            nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 2, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 2, affine=True),
+            nn.ReLU(),
+            # [128, 32, 32]
+            nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 4, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 4, affine=True),
+            nn.ReLU(),
+            # [256, 16, 16]
+            nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            # [512, 8, 8]
+            nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2),
+            nn.Flatten(),
+            nn.Linear(nf * 8 * 4 * 2, 100),
+            nn.ReLU(),
+            nn.Linear(100, num_classes)
+        )
+
+        # These normalization constants should be derived experimentally.
+        self.log_fft_mean = torch.tensor([-3.5184, -4.071]).view(1,1,1,2)
+        self.log_fft_std = torch.tensor([3.1660, 3.8042]).view(1,1,1,2)
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x_c = x.view(c*b, h, w)
+        x_c = torch.view_as_real(torch.fft.rfft(x_c))
+
+        # Log-normalize spectrogram
+        x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16)
+        x_c = torch.log(x_c)
+        x_c = (x_c - self.log_fft_mean.to(x.device)) / self.log_fft_std.to(x.device)
+
+        # Return to expected input shape (b,c,h,w)
+        x_c = x_c.permute(0, 3, 1, 2).reshape(b, c * 2, h, w // 2 + 1)
+
+        return self.net(x_c)
+
+
+@register_model
+def register_wide_kernel_vgg(opt_net, opt):
+    """ return a WideKernelVgg object
+    """
+    return WideKernelVgg(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    vgg = WideKernelVgg()
+    vgg(torch.randn(1,3,256,256))
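As a shape check, a small standalone sketch (not part of the commit) of why the first Conv2d in WideKernelVgg takes 6 input channels: rfft over the width dimension yields w // 2 + 1 frequency bins, view_as_real appends a (real, imaginary) axis, and the final permute/reshape folds that axis into the channel dimension, so a (b, 3, h, w) image becomes a (b, 6, h, w // 2 + 1) tensor.

# Shape walk-through of WideKernelVgg.forward's FFT preprocessing (illustration only).
import torch

b, c, h, w = 1, 3, 256, 256
x = torch.randn(b, c, h, w)

x_c = x.view(c * b, h, w)
x_c = torch.view_as_real(torch.fft.rfft(x_c))               # (c*b, h, w//2 + 1, 2)
x_c = torch.log((x_c.abs() ** 2).clip(min=1e-8, max=1e16))  # log power spectrum
x_c = x_c.permute(0, 3, 1, 2).reshape(b, c * 2, h, w // 2 + 1)

print(x_c.shape)  # torch.Size([1, 6, 256, 129]) -> matches nn.Conv2d(6, nf, 7, 1, 3)

With a 256x256 input the spectrogram is 256x129, and the strided convolutions plus the final 2x2 max pool reduce it to 4x2, which is where the nn.Linear(nf * 8 * 4 * 2, 100) input size comes from.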
@@ -13,14 +13,14 @@ import torch
 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 4
-    opt['compression_level'] = 90 # JPEG compression quality rating.
+    opt['n_thread'] = 8
+    opt['compression_level'] = 95 # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.

     opt['dest'] = 'file'
-    opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5'
+    opt['input_folder'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
+    opt['save_folder'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked_pt2'
     opt['crop_sz'] = [256, 512] # the size of each sub-image
     opt['step'] = [256, 512] # step of the sliding crop window
     opt['exclusions'] = [[],[]] # image names matching these terms wont be included in the processing.
@@ -28,6 +28,8 @@ def main():
     opt['resize_final_img'] = [1, .5]
     opt['only_resize'] = False
     opt['vertical_split'] = False
+    opt['use_masking'] = True
+    opt['mask_path'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new_masks'
     opt['input_image_max_size_before_being_halved'] = 5500 # As described, images larger than this dimensional size will be halved before anything else is done.
     # This helps prevent images from cameras with "false-megapixels" from polluting the dataset.
     # False-megapixel=lots of noise at ultra-high res.
@@ -150,7 +152,7 @@ class TiledDataset(data.Dataset):
         # Wrap in a tuple to align with split mode.
         return (self.get(index, False, False), None)

-    def get_for_scale(self, img, crop_sz, step, resize_factor, ref_resize_factor):
+    def get_for_scale(self, img, mask, crop_sz, step, resize_factor, ref_resize_factor):
         thres_sz = self.opt['thres_sz']
         h, w, c = img.shape

@@ -172,6 +174,15 @@ class TiledDataset(data.Dataset):
            for y in w_space:
                index += 1
                crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
+                if mask is not None:
+                    def mask_map(inp):
+                        mask_factor = 256 / (crop_sz * ref_resize_factor)
+                        return int(inp * mask_factor)
+                    crop_mask = mask[mask_map(x):mask_map(x+crop_sz),
+                                     mask_map(y):mask_map(y+crop_sz),
+                                     :]
+                    if crop_mask.mean() < 255 / 2:  # If at least 50% of the image isn't made up of the type of pixels we want to process, ignore this tile.
+                        continue
                # Center point needs to be resized by ref_resize_factor - since it is relative to the reference image.
                center_point = (int((x + crop_sz // 2) // ref_resize_factor), int((y + crop_sz // 2) // ref_resize_factor))
                crop_img = np.ascontiguousarray(crop_img)
@@ -185,10 +196,11 @@ class TiledDataset(data.Dataset):
     def get(self, index, split_mode, left_img):
         path = self.images[index]
         img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

         if img is None or len(img.shape) == 2:
             return None

+        mask = cv2.imread(os.path.join(self.opt['mask_path'], os.path.basename(path) + ".png"), cv2.IMREAD_UNCHANGED) if self.opt['use_masking'] else None

         h, w, c = img.shape

         if max(h,w) > self.opt['input_image_max_size_before_being_halved']:
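For completeness, a small standalone sketch (not part of the commit) of producing a mask file the loading code above would pick up: masks are looked up as <mask_path>/<image basename>.png and read with OpenCV, and get_for_scale() keeps a tile only when the mean of its mask crop is at least 255 / 2, so an 8-bit image with 255 marking wanted pixels and 0 elsewhere works. The example file name and the 256x256 mask size here are assumptions for illustration only.

# Write a keep/discard mask compatible with the crop_mask.mean() < 255 / 2 test.
import os
import cv2
import numpy as np

mask_path = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new_masks'
image_name = 'example.jpg'  # hypothetical source image name

mask = np.zeros((256, 256), dtype=np.uint8)  # assumed mask resolution
mask[64:192, 64:192] = 255                   # 255 = pixels we want to keep tiles from
os.makedirs(mask_path, exist_ok=True)
cv2.imwrite(os.path.join(mask_path, image_name + '.png'), mask)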
@@ -248,7 +260,7 @@ class TiledDataset(data.Dataset):
                     break
             if excluded:
                 continue
-            results.extend(self.get_for_scale(img, crop_sz, step, resize_factor, ref_resize_factor))
+            results.extend(self.get_for_scale(img, mask, crop_sz, step, resize_factor, ref_resize_factor))
         return results, path

     def __len__(self):
@@ -299,7 +299,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_xstart.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -101,8 +101,8 @@ class ConfigurableStep(Module):
                                     betas=(opt_config['beta1'], opt_config['beta2']))
         elif self.step_opt['optimizer'] == 'adamw':
             opt = torch.optim.AdamW(list(optim_params.values()),
-                                    weight_decay=opt_config['weight_decay'],
-                                    betas=(opt_config['beta1'], opt_config['beta2']))
+                                    weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
+                                    betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
         elif self.step_opt['optimizer'] == 'lars':
             from trainer.optimizers.larc import LARC
             from trainer.optimizers.sgd import SGDNoBiasMomentum
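A quick aside on the switch from direct indexing to opt_get here: throughout these diffs opt_get(config, [keys...], default) acts as a nested lookup that falls back to a default when the option is missing, which is what makes beta1/beta2/weight_decay optional in the YAML. A rough standalone sketch of that behaviour (an assumption based on usage, not the actual utils.util implementation):

# Assumed behaviour of an opt_get-style helper: walk nested dict keys, return a
# default when any key is missing.
def opt_get(opt, keys, default=None):
    ret = opt
    for k in keys:
        if ret is None or k not in ret:
            return default
        ret = ret[k]
    return ret

opt_config = {'beta1': 0.9}
print(opt_get(opt_config, ['weight_decay'], 1e-2))  # -> 0.01 (fallback to default)
print(opt_get(opt_config, ['beta1'], .9))           # -> 0.9 (value from the config)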
@@ -4,25 +4,28 @@ import time
 import argparse

 import os

+from torchvision.transforms import CenterCrop
+
+from trainer.ExtensibleTrainer import ExtensibleTrainer
+
 from utils import options as option
 import utils.util as util
 from data import create_dataset, create_dataloader
-from models import create_model
 from tqdm import tqdm
 import torch
 import torchvision


 if __name__ == "__main__":
-    bin_path = "f:\\binned"
-    good_path = "f:\\good"
+    bin_path = "f:\\tmp\\binned"
+    good_path = "f:\\tmp\\good"
     os.makedirs(bin_path, exist_ok=True)
     os.makedirs(good_path, exist_ok=True)

     torch.backends.cudnn.benchmark = True
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/discriminator_filter.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_quality_detectors/train_resnet_jpeg.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     opt['dist'] = False
@@ -43,7 +46,7 @@ if __name__ == "__main__":
         logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
         test_loaders.append(test_loader)

-    model = create_model(opt)
+    model = ExtensibleTrainer(opt)
     fea_loss = 0
     for test_loader in test_loaders:
         test_set_name = test_loader.dataset.opt['name']
@@ -55,19 +58,17 @@ if __name__ == "__main__":
         tq = tqdm(test_loader)
         removed = 0
         means = []
-        dataset_mean = -7.133
-        for data in tq:
-            model.feed_data(data, need_GT=True)
+        for k, data in enumerate(tq):
+            model.feed_data(data, k)
             model.test()
-            results = model.eval_state['discriminator_out'][0]
-            means.append(torch.mean(results).item())
-            print(sum(means)/len(means), torch.mean(results), torch.max(results), torch.min(results))
+            results = torch.argmax(torch.nn.functional.softmax(model.eval_state['logits'][0], dim=-1), dim=1)
             for i in range(results.shape[0]):
-                #if results[i] < .8:
-                # os.remove(data['GT_path'][i])
-                # removed += 1
-                imname = osp.basename(data['GT_path'][i])
-                if results[i]-dataset_mean > 1:
-                    torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
+                if results[i] == 0:
+                    imname = osp.basename(data['HQ_path'][i])
+                    # For VERIFICATION:
+                    #torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
+                    # 4 REALZ:
+                    os.remove(data['HQ_path'][i])
+                    removed += 1

         print("Removed %i/%i images" % (removed, len(test_set)))
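One note on the new filtering condition, with a standalone sketch (not part of the commit): softmax is monotonic, so argmax over the softmaxed logits picks the same class as argmax over the raw logits; the softmax call above is only for readability, and class 0 is treated as the "remove this image" decision.

# Illustration: argmax(softmax(logits)) == argmax(logits) for 2-class outputs.
import torch

logits = torch.tensor([[2.3, -0.7], [-1.1, 0.4]])  # hypothetical classifier outputs
probs = torch.nn.functional.softmax(logits, dim=-1)

print(torch.argmax(probs, dim=1))   # tensor([0, 1])
print(torch.argmax(logits, dim=1))  # tensor([0, 1]) - identical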