diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py
index 58073161..796db99a 100644
--- a/codes/data/image_corruptor.py
+++ b/codes/data/image_corruptor.py
@@ -9,6 +9,8 @@ from io import BytesIO
 
 
 # Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data)
+from utils.util import opt_get
+
 '''
 if __name__ == '__main__':
     import numpy as np
@@ -23,13 +25,15 @@ if __name__ == '__main__':
 class ImageCorruptor:
     def __init__(self, opt):
         self.opt = opt
+        self.reset_random()
         self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
         self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
         self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
+        self.cosine_bias = opt_get(opt, ['cosine_bias'], True)
         if self.num_corrupts == 0:
             return
-        self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
-        self.reset_random()
+        else:
+            self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
 
     def reset_random(self):
         if 'random_seed' in self.opt.keys():
@@ -41,7 +45,10 @@ class ImageCorruptor:
     # Return is on [0,1] with a bias towards 0.
     def get_rand(self):
         r = self.rand.random()
-        return 1 - cos(r * pi / 2)
+        if self.cosine_bias:
+            return 1 - cos(r * pi / 2)
+        else:
+            return r
 
     def corrupt_images(self, imgs, return_entropy=False):
         if self.num_corrupts == 0 and not self.fixed_corruptions:
@@ -82,10 +89,10 @@ class ImageCorruptor:
                 img = (img // quant_div) * quant_div
                 img = img / 255
             elif 'gaussian_blur' in aug:
-                img = cv2.GaussianBlur(img, (0,0), rand_val*1.5)
+                img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
             elif 'motion_blur' in aug:
                 # Motion blur
-                intensity = self.blur_scale * rand_val * 3 + 1
+                intensity = self.blur_scale*rand_val * 3 + 1
                 angle = random.randint(0,360)
                 k = np.zeros((intensity, intensity), dtype=np.float32)
                 k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
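Note on the new cosine_bias option in ImageCorruptor above: when left at its default of True, get_rand() keeps the old behaviour of remapping a uniform draw through a cosine so corruption strengths cluster near 0; setting it to False yields a plain uniform draw. A minimal standalone sketch of the two behaviours (the helper name is illustrative, not part of the patch):

from math import cos, pi
import random

def biased_rand(cosine_bias=True):
    # Mirrors ImageCorruptor.get_rand(): a uniform draw on [0, 1],
    # optionally remapped by 1 - cos(r * pi / 2) to bias results toward 0.
    r = random.random()
    return 1 - cos(r * pi / 2) if cosine_bias else r

# e.g. the median draw r = 0.5 maps to 1 - cos(pi / 4) ≈ 0.293 with the bias, versus 0.5 without it.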
diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py
index db51c383..18e7560b 100644
--- a/codes/data/image_folder_dataset.py
+++ b/codes/data/image_folder_dataset.py
@@ -1,3 +1,4 @@
+import functools
 import glob
 import itertools
 import random
@@ -11,7 +12,7 @@ import os
 
 import torchvision
 from torch.utils.data import DataLoader
-from torchvision.transforms import Normalize
+from torchvision.transforms import Normalize, CenterCrop
 from tqdm import tqdm
 
 from data import util
@@ -21,10 +22,19 @@ from data.image_label_parser import VsNetImageLabeler
 from utils.util import opt_get
 
 
+def ndarray_center_crop(crop, img):
+    y, x, c = img.shape
+    startx = x // 2 - crop // 2
+    starty = y // 2 - crop // 2
+    return img[starty:starty + crop, startx:startx + crop, :]
+
+
 class ImageFolderDataset:
     def __init__(self, opt):
         self.opt = opt
         self.corruptor = ImageCorruptor(opt)
+        if 'center_crop_hq_sz' in opt.keys():
+            self.center_crop = functools.partial(ndarray_center_crop, opt['center_crop_hq_sz'])
         self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
         self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
         self.scale = opt['scale']
@@ -132,6 +142,8 @@ class ImageFolderDataset:
 
     def __getitem__(self, item):
         hq = util.read_img(None, self.image_paths[item], rgb=True)
+        if hasattr(self, 'center_crop'):
+            hq = self.center_crop(hq)
 
         if not self.disable_flip and random.random() < .5:
             hq = hq[:, ::-1, :]
@@ -223,25 +235,41 @@ class ImageFolderDataset:
 
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
+        'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked'],
         'weights': [1],
         'target_size': 256,
-        'force_multiple': 1,
         'scale': 2,
         'corrupt_before_downsize': True,
         'fetch_alt_image': False,
         'disable_flip': True,
-        'fixed_corruptions': [ 'jpeg-broad', 'gaussian_blur' ],
+        'fixed_corruptions': [ 'jpeg-medium' ],
         'num_corrupts_per_image': 0,
         'corruption_blur_scale': 0
     }
-    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0)
+    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=4, batch_size=64)
     import os
-    output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
+    output_path = 'F:\\tmp'
     os.makedirs(output_path, exist_ok=True)
+    res = []
     for i, d in tqdm(enumerate(ds)):
-        lq = d['lq']
-        #torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
+        '''
+        x = d['hq']
+        b,c,h,w = x.shape
+        x_c = x.view(c*b, h, w)
+        x_c = torch.view_as_real(torch.fft.rfft(x_c))
+        # Log-normalize spectrogram
+        x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16)
+        x_c = torch.log(x_c)
+        res.append(x_c)
+        if i % 100 == 99:
+            stacked = torch.cat(res, dim=0)
+            print(stacked.mean(dim=[0,1,2]), stacked.std(dim=[0,1,2]))
+        '''
+
+        for k, v in d.items():
+            if isinstance(v, torch.Tensor) and len(v.shape) >= 3:
+                os.makedirs(f'{output_path}\\{k}', exist_ok=True)
+                torchvision.utils.save_image(v, f'{output_path}\\{k}\\{i}.png')
         if i >= 200000:
-            break
\ No newline at end of file
+            break
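Usage note on the new center_crop_hq_sz option handled above: functools.partial pre-binds the crop size so that __getitem__ can apply the crop as a one-argument function on the loaded HQ ndarray. A small self-contained sketch (the array shapes and the value 256 are made up for illustration):

import functools
import numpy as np

def ndarray_center_crop(crop, img):
    # Crop an (H, W, C) ndarray to a centered (crop, crop, C) window, as in the patch.
    y, x, c = img.shape
    startx = x // 2 - crop // 2
    starty = y // 2 - crop // 2
    return img[starty:starty + crop, startx:startx + crop, :]

center_crop = functools.partial(ndarray_center_crop, 256)  # 256 standing in for opt['center_crop_hq_sz']
hq = center_crop(np.zeros((512, 768, 3)))  # -> shape (256, 256, 3)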
diff --git a/codes/models/classifiers/torch_models.py b/codes/models/classifiers/torch_models.py
new file mode 100644
index 00000000..998c304c
--- /dev/null
+++ b/codes/models/classifiers/torch_models.py
@@ -0,0 +1,11 @@
+from torchvision.models import vgg16
+
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+@register_model
+def register_torch_vgg16(opt_net, opt):
+    """ Returns a torchvision VGG-16 built with the configured kwargs.
+    """
+    return vgg16(**opt_get(opt_net, ['kwargs'], {}))
diff --git a/codes/models/classifiers/wide_kernel_vgg.py b/codes/models/classifiers/wide_kernel_vgg.py
new file mode 100644
index 00000000..d4af0ee0
--- /dev/null
+++ b/codes/models/classifiers/wide_kernel_vgg.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+class WideKernelVgg(nn.Module):
+    def __init__(self, nf=64, num_classes=2):
+        super().__init__()
+        self.net = nn.Sequential(
+            # [64, 128, 128]
+            nn.Conv2d(6, nf, 7, 1, 3, bias=True),
+            nn.BatchNorm2d(nf, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf, nf, 7, 1, 3, bias=False),
+            nn.BatchNorm2d(nf, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf, nf, 5, 2, 2, bias=False),
+            nn.BatchNorm2d(nf, affine=True),
+            nn.ReLU(),
+            # [64, 64, 64]
+            nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 2, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 2, affine=True),
+            nn.ReLU(),
+            # [128, 32, 32]
+            nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 4, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 4, affine=True),
+            nn.ReLU(),
+            # [256, 16, 16]
+            nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            # [512, 8, 8]
+            nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nf * 8, affine=True),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2),
+            nn.Flatten(),
+            nn.Linear(nf * 8 * 4 * 2, 100),
+            nn.ReLU(),
+            nn.Linear(100, num_classes)
+        )
+
+        # These normalization constants should be derived experimentally.
+        self.log_fft_mean = torch.tensor([-3.5184, -4.071]).view(1,1,1,2)
+        self.log_fft_std = torch.tensor([3.1660, 3.8042]).view(1,1,1,2)
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x_c = x.view(c*b, h, w)
+        x_c = torch.view_as_real(torch.fft.rfft(x_c))
+
+        # Log-normalize spectrogram
+        x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16)
+        x_c = torch.log(x_c)
+        x_c = (x_c - self.log_fft_mean.to(x.device)) / self.log_fft_std.to(x.device)
+
+        # Return to expected input shape (b,c,h,w)
+        x_c = x_c.permute(0, 3, 1, 2).reshape(b, c * 2, h, w // 2 + 1)
+
+        return self.net(x_c)
+
+
+@register_model
+def register_wide_kernel_vgg(opt_net, opt):
+    """ Returns a WideKernelVgg built with the configured kwargs.
+    """
+    return WideKernelVgg(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    vgg = WideKernelVgg()
+    vgg(torch.randn(1,3,256,256))
\ No newline at end of file
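The log_fft_mean / log_fft_std buffers above are flagged as experimentally derived; the commented-out loop in image_folder_dataset.py's __main__ block computes exactly these statistics over a dataset. A condensed sketch of that derivation with the dataset wiring omitted (the function name is illustrative):

import torch

def log_fft_stats(batches):
    # batches: an iterable of (b, c, h, w) image tensors in [0, 1].
    feats = []
    for x in batches:
        b, c, h, w = x.shape
        x_c = torch.view_as_real(torch.fft.rfft(x.view(c * b, h, w)))
        # Same log-power transform that WideKernelVgg.forward() applies before normalizing.
        x_c = torch.log((x_c.abs() ** 2).clip(min=1e-8, max=1e16))
        feats.append(x_c)
    stacked = torch.cat(feats, dim=0)
    # One mean/std per real/imaginary component, matching the .view(1, 1, 1, 2) buffers.
    return stacked.mean(dim=[0, 1, 2]), stacked.std(dim=[0, 1, 2])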
diff --git a/codes/scripts/extract_subimages_with_ref.py b/codes/scripts/extract_subimages_with_ref.py
index a7df57c3..63d0180c 100644
--- a/codes/scripts/extract_subimages_with_ref.py
+++ b/codes/scripts/extract_subimages_with_ref.py
@@ -13,14 +13,14 @@ import torch
 
 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 4
-    opt['compression_level'] = 90 # JPEG compression quality rating.
+    opt['n_thread'] = 8
+    opt['compression_level'] = 95 # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.
     opt['dest'] = 'file'
-    opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5'
+    opt['input_folder'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
+    opt['save_folder'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\256_only_humans_masked_pt2'
     opt['crop_sz'] = [256, 512] # the size of each sub-image
     opt['step'] = [256, 512] # step of the sliding crop window
     opt['exclusions'] = [[],[]] # image names matching these terms wont be included in the processing.
@@ -28,6 +28,8 @@ def main():
     opt['resize_final_img'] = [1, .5]
     opt['only_resize'] = False
     opt['vertical_split'] = False
+    opt['use_masking'] = True
+    opt['mask_path'] = 'E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new_masks'
     opt['input_image_max_size_before_being_halved'] = 5500 # As described, images larger than this dimensional size will be halved before anything else is done.
     # This helps prevent images from cameras with "false-megapixels" from polluting the dataset.
     # False-megapixel=lots of noise at ultra-high res.
@@ -150,7 +152,7 @@ class TiledDataset(data.Dataset):
         # Wrap in a tuple to align with split mode.
         return (self.get(index, False, False), None)
 
-    def get_for_scale(self, img, crop_sz, step, resize_factor, ref_resize_factor):
+    def get_for_scale(self, img, mask, crop_sz, step, resize_factor, ref_resize_factor):
         thres_sz = self.opt['thres_sz']
         h, w, c = img.shape
 
@@ -172,6 +174,15 @@ class TiledDataset(data.Dataset):
             for y in w_space:
                 index += 1
                 crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
+                if mask is not None:
+                    def mask_map(inp):
+                        mask_factor = 256 / (crop_sz * ref_resize_factor)
+                        return int(inp * mask_factor)
+                    crop_mask = mask[mask_map(x):mask_map(x+crop_sz),
+                                     mask_map(y):mask_map(y+crop_sz),
+                                     :]
+                    if crop_mask.mean() < 255 / 2: # If at least 50% of the image isn't made up of the type of pixels we want to process, ignore this tile.
+                        continue
                 # Center point needs to be resized by ref_resize_factor - since it is relative to the reference image.
                 center_point = (int((x + crop_sz // 2) // ref_resize_factor), int((y + crop_sz // 2) // ref_resize_factor))
                 crop_img = np.ascontiguousarray(crop_img)
@@ -185,10 +196,11 @@ class TiledDataset(data.Dataset):
 
     def get(self, index, split_mode, left_img):
         path = self.images[index]
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
-
         if img is None or len(img.shape) == 2:
             return None
+        mask = cv2.imread(os.path.join(self.opt['mask_path'], os.path.basename(path) + ".png"), cv2.IMREAD_UNCHANGED) if self.opt['use_masking'] else None
+
         h, w, c = img.shape
 
         if max(h,w) > self.opt['input_image_max_size_before_being_halved']:
@@ -248,7 +260,7 @@ class TiledDataset(data.Dataset):
                     break
             if excluded:
                 continue
-            results.extend(self.get_for_scale(img, crop_sz, step, resize_factor, ref_resize_factor))
+            results.extend(self.get_for_scale(img, mask, crop_sz, step, resize_factor, ref_resize_factor))
         return results, path
 
     def __len__(self):
diff --git a/codes/train.py b/codes/train.py
index 917be288..daf49b0d 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -299,7 +299,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_xstart.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index c5064c8c..b8d52f9e 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -101,8 +101,8 @@ class ConfigurableStep(Module):
                                    betas=(opt_config['beta1'], opt_config['beta2']))
         elif self.step_opt['optimizer'] == 'adamw':
             opt = torch.optim.AdamW(list(optim_params.values()),
-                                    weight_decay=opt_config['weight_decay'],
-                                    betas=(opt_config['beta1'], opt_config['beta2']))
+                                    weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
+                                    betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
         elif self.step_opt['optimizer'] == 'lars':
             from trainer.optimizers.larc import LARC
             from trainer.optimizers.sgd import SGDNoBiasMomentum
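The AdamW branch above now tolerates configs that omit optimizer hyperparameters by falling back to the usual PyTorch defaults. opt_get is the repo's lookup helper from utils.util; a simplified stand-in for the behaviour being relied on (assumed semantics: walk a nested key path and return the default when any key is missing):

def opt_get(opt, keys, default=None):
    # Simplified stand-in: descend through nested dicts, falling back to `default`.
    for k in keys:
        if not isinstance(opt, dict) or k not in opt:
            return default
        opt = opt[k]
    return opt

opt_config = {'beta1': 0.9}  # hypothetical step config with most fields omitted
weight_decay = opt_get(opt_config, ['weight_decay'], 1e-2)  # -> 1e-2
betas = (opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))  # -> (0.9, 0.999)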
diff --git a/codes/scripts/use_discriminator_as_filter.py b/codes/use_discriminator_as_filter.py
similarity index 70%
rename from codes/scripts/use_discriminator_as_filter.py
rename to codes/use_discriminator_as_filter.py
index 14cb6887..d130b0ff 100644
--- a/codes/scripts/use_discriminator_as_filter.py
+++ b/codes/use_discriminator_as_filter.py
@@ -4,25 +4,28 @@
 import time
 import argparse
 import os
+
+from torchvision.transforms import CenterCrop
+
+from trainer.ExtensibleTrainer import ExtensibleTrainer
 from utils import options as option
 import utils.util as util
 from data import create_dataset, create_dataloader
-from models import create_model
 from tqdm import tqdm
 import torch
 import torchvision
 
 
 if __name__ == "__main__":
-    bin_path = "f:\\binned"
-    good_path = "f:\\good"
+    bin_path = "f:\\tmp\\binned"
+    good_path = "f:\\tmp\\good"
     os.makedirs(bin_path, exist_ok=True)
     os.makedirs(good_path, exist_ok=True)
     torch.backends.cudnn.benchmark = True
 
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/discriminator_filter.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_quality_detectors/train_resnet_jpeg.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     opt['dist'] = False
 
@@ -43,7 +46,7 @@ if __name__ == "__main__":
         logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
         test_loaders.append(test_loader)
 
-    model = create_model(opt)
+    model = ExtensibleTrainer(opt)
     fea_loss = 0
     for test_loader in test_loaders:
         test_set_name = test_loader.dataset.opt['name']
@@ -55,19 +58,17 @@ if __name__ == "__main__":
         tq = tqdm(test_loader)
         removed = 0
         means = []
-        dataset_mean = -7.133
-        for data in tq:
-            model.feed_data(data, need_GT=True)
+        for k, data in enumerate(tq):
+            model.feed_data(data, k)
             model.test()
-            results = model.eval_state['discriminator_out'][0]
-            means.append(torch.mean(results).item())
-            print(sum(means)/len(means), torch.mean(results), torch.max(results), torch.min(results))
+            results = torch.argmax(torch.nn.functional.softmax(model.eval_state['logits'][0], dim=-1), dim=1)
             for i in range(results.shape[0]):
-                #if results[i] < .8:
-                #    os.remove(data['GT_path'][i])
-                #    removed += 1
-                imname = osp.basename(data['GT_path'][i])
-                if results[i]-dataset_mean > 1:
-                    torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
+                if results[i] == 0:
+                    imname = osp.basename(data['HQ_path'][i])
+                    # For VERIFICATION:
+                    #torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
+                    # 4 REALZ:
+                    os.remove(data['HQ_path'][i])
+                    removed += 1
 
         print("Removed %i/%i images" % (removed, len(test_set)))
\ No newline at end of file
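For reference, the filtering rule in the rewritten loop above boils down to: run the quality classifier, take the argmax class of its logits, and delete the source image whenever the predicted class is 0 (the softmax is kept from the original code even though argmax over the raw logits would pick the same class). A standalone sketch of just that decision, with model and data loading assumed to happen elsewhere:

import os
import torch
import torch.nn.functional as F

def filter_paths(logits, paths, dry_run=True):
    # logits: (batch, num_classes) classifier output; paths: matching source image paths.
    preds = torch.argmax(F.softmax(logits, dim=-1), dim=1)
    removed = 0
    for i, path in enumerate(paths):
        if preds[i] == 0:  # class 0 is the "remove" class in this script
            if not dry_run:
                os.remove(path)  # destructive; leave dry_run=True to only count
            removed += 1
    return removed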