From 4914c526dcc94babb25650d5d7de269f815a8e07 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 29 Sep 2021 14:24:49 -0600 Subject: [PATCH] More cleanup --- codes/models/RRDBNet_arch.py | 100 ------------ codes/models/vqvae/kmeans_mask_producer.py | 48 ------ codes/scripts/byol/byol_uresnet_playground.py | 142 ------------------ codes/trainer/loss.py | 97 ------------ 4 files changed, 387 deletions(-) delete mode 100644 codes/models/vqvae/kmeans_mask_producer.py delete mode 100644 codes/scripts/byol/byol_uresnet_playground.py diff --git a/codes/models/RRDBNet_arch.py b/codes/models/RRDBNet_arch.py index a70572f0..d8b05259 100644 --- a/codes/models/RRDBNet_arch.py +++ b/codes/models/RRDBNet_arch.py @@ -9,7 +9,6 @@ import torchvision from torchvision.models.resnet import Bottleneck from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu -from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3 from trainer.networks import register_model from utils.util import checkpoint, sequential_checkpoint, opt_get from models.switched_conv.switched_conv import SwitchedConv @@ -305,94 +304,6 @@ class RRDBNet(nn.Module): if hasattr(bm, 'bypass_map'): torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) - -class RRDBNetSwitchedConv(nn.Module): - def __init__(self, - in_channels, - out_channels, - mid_channels=64, - num_blocks=23, - growth_channels=32, - body_block=RRDB, - blocks_per_checkpoint=1, - scale=4, - initial_stride=1, - use_ref=False, # When set, a reference image is expected as input and synthesized if not found. Useful for video SR. - resnet_encoder_dict=None - ): - super().__init__() - self.num_blocks = num_blocks - self.blocks_per_checkpoint = blocks_per_checkpoint - self.scale = scale - self.in_channels = in_channels - self.use_ref = use_ref - first_conv_stride = initial_stride if not self.use_ref else scale - first_conv_ksize = 3 if first_conv_stride == 1 else 7 - first_conv_padding = 1 if first_conv_stride == 1 else 3 - self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding) - self.reduce_ch = mid_channels - reduce_to = None - self.body = make_layer( - body_block, - num_blocks, - mid_channels=mid_channels, - growth_channels=growth_channels, - reduce_to=reduce_to) - self.conv_body = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64) - # upsample - self.conv_up1 = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64) - self.conv_up2 = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64) - if scale >= 8: - self.conv_up3 = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64) - else: - self.conv_up3 = None - self.conv_hr = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64) - self.conv_last = SwitchedConv(self.reduce_ch, out_channels, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - self.resnet_encoder = UResNet50_3(Bottleneck, [3, 4, 6, 3], out_dim=64) - if resnet_encoder_dict: - self.resnet_encoder.load_state_dict(torch.load(resnet_encoder_dict)) - - for m in [ - self.conv_first, self.conv_body, self.conv_up1, - self.conv_up2, self.conv_up3, self.conv_hr, self.conv_last - ]: - if m is not None: - default_init_weights(m, 0.1) - - def 
forward(self, x, ref=None): - switch_enc = checkpoint(self.resnet_encoder, F.interpolate(x, scale_factor=2, mode="bilinear")) - - x_lg = x - feat = self.conv_first(x_lg) - feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat) - feat = feat[:, :self.reduce_ch] - body_feat = checkpoint(self.conv_body, feat, switch_enc) - feat = feat + body_feat - - # upsample - out = self.lrelu( - checkpoint(self.conv_up1, F.interpolate(feat, scale_factor=2, mode='nearest'), switch_enc)) - if self.scale >= 4: - out = self.lrelu( - checkpoint(self.conv_up2, F.interpolate(out, scale_factor=2, mode='nearest'), switch_enc)) - if self.scale >= 8: - out = self.lrelu( - self.conv_up3(F.interpolate(out, scale_factor=2, mode='nearest'), switch_enc)) - else: - out = self.lrelu(checkpoint(self.conv_up2, out, switch_enc)) - out = checkpoint(self.conv_hr, out, switch_enc) - out = checkpoint(self.conv_last, self.lrelu(out), switch_enc) - return out - - def visual_dbg(self, step, path): - for i, bm in enumerate(self.body): - if hasattr(bm, 'bypass_map'): - torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) - - @register_model def register_RRDBNetBypass(opt_net, opt): additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not' @@ -418,14 +329,3 @@ def register_RRDBNet(opt_net, opt): output_mode=output_mode, body_block=RRDB, scale=opt_net['scale'], growth_channels=gc, initial_stride=initial_stride) - -@register_model -def register_rrdb_switched_conv(opt_net, opt): - gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32 - initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1 - bypass_noise = opt_get(opt_net, ['bypass_noise'], False) - block = functools.partial(RRDBWithBypass, randomly_add_noise_to_bypass=bypass_noise) - return RRDBNetSwitchedConv(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], - mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], - body_block=block, scale=opt_net['scale'], growth_channels=gc, - initial_stride=initial_stride, resnet_encoder_dict=opt_net['switch_encoder']) diff --git a/codes/models/vqvae/kmeans_mask_producer.py b/codes/models/vqvae/kmeans_mask_producer.py deleted file mode 100644 index c1ca5fc1..00000000 --- a/codes/models/vqvae/kmeans_mask_producer.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F -from torchvision.models.resnet import Bottleneck - -from models.pixel_level_contrastive_learning.resnet_unet import UResNet50 -from trainer.networks import register_model -from utils.kmeans import kmeans_predict -from utils.util import opt_get - - -class UResnetMaskProducer(nn.Module): - def __init__(self, pretrained_uresnet_path, kmeans_centroid_path, mask_scales=[.125,.25,.5,1], tail_dim=512): - super().__init__() - _, centroids = torch.load(kmeans_centroid_path) - self.centroids = nn.Parameter(centroids) - self.ures = UResNet50(Bottleneck, [3,4,6,3], out_dim=tail_dim).to('cuda') - self.mask_scales = mask_scales - - sd = torch.load(pretrained_uresnet_path) - # An assumption is made that the state_dict came from a byol model. Strip out unnecessary weights.. - resnet_sd = {} - for k, v in sd.items(): - if 'target_encoder.net.' 
in k: - resnet_sd[k.replace('target_encoder.net.', '')] = v - - self.ures.load_state_dict(resnet_sd, strict=True) - self.ures.eval() - - def forward(self, x): - with torch.no_grad(): - latents = self.ures(x) - b,c,h,w = latents.shape - latents = latents.permute(0,2,3,1).reshape(b*h*w,c) - masks = kmeans_predict(latents, self.centroids).float() - masks = masks.reshape(b,1,h,w) - interpolated_masks = {} - for sf in self.mask_scales: - dim_h, dim_w = int(sf*x.shape[-2]), int(sf*x.shape[-1]) - imask = F.interpolate(masks, size=(dim_h,dim_w), mode="nearest") - interpolated_masks[dim_w] = imask.long() - return interpolated_masks - - -@register_model -def register_uresnet_mask_producer(opt_net, opt): - kw = opt_get(opt_net, ['kwargs'], {}) - return UResnetMaskProducer(**kw) diff --git a/codes/scripts/byol/byol_uresnet_playground.py b/codes/scripts/byol/byol_uresnet_playground.py deleted file mode 100644 index 5bbf69ea..00000000 --- a/codes/scripts/byol/byol_uresnet_playground.py +++ /dev/null @@ -1,142 +0,0 @@ -import os -from random import shuffle - -import matplotlib.cm as cm -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision -from torch.utils.data import DataLoader -from torchvision.models.resnet import Bottleneck -from tqdm import tqdm - -from data.image_folder_dataset import ImageFolderDataset -from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3 - -# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved -# and the distance is computed across the channel dimension. -from utils.kmeans import kmeans, kmeans_predict -from utils.options import dict_to_nonedict - - -def structural_euc_dist(x, y): - diff = torch.square(x - y) - sum = torch.sum(diff, dim=-1) - return torch.sqrt(sum) - - -def cosine_similarity(x, y): - x = norm(x) - y = norm(y) - return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself. 
- - -def key_value_difference(x, y): - x = F.normalize(x, dim=-1, p=2) - y = F.normalize(y, dim=-1, p=2) - return 2 - 2 * (x * y).sum(dim=-1) - - -def norm(x): - sh = x.shape - sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))]) - return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r) - - -def im_norm(x): - return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5 - - -def get_image_folder_dataloader(batch_size, num_workers, target_size=256): - dataset_opt = dict_to_nonedict({ - 'name': 'amalgam', - #'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'], - #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], - 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'], - #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], - 'weights': [1], - 'target_size': target_size, - 'force_multiple': 32, - 'scale': 1 - }) - dataset = ImageFolderDataset(dataset_opt) - return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) - - -def produce_latent_dict(model, basename): - batch_size = 64 - num_workers = 4 - dataloader = get_image_folder_dataloader(batch_size, num_workers) - id = 0 - paths = [] - latents = [] - prob = None - for batch in tqdm(dataloader): - hq = batch['hq'].to('cuda') - l = model(hq) - b, c, h, w = l.shape - dim = b*h*w - l = l.permute(0,2,3,1).reshape(dim, c).cpu() - # extract a random set of 10 latents from each image - if prob is None: - prob = torch.full((dim,), 1/(dim)) - l = l[prob.multinomial(num_samples=100, replacement=False)].split(1, dim=0) - latents.extend(l) - paths.extend(batch['HQ_path']) - id += batch_size - if id > 5000: - print("Saving checkpoint..") - torch.save((latents, paths), f'../{basename}_latent_dict.pth') - id = 0 - - -def build_kmeans(basename): - latents, _ = torch.load(f'../{basename}_latent_dict.pth') - shuffle(latents) - latents = torch.cat(latents, dim=0).to('cuda') - cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'), tol=0, iter_limit=5000, gravity_limit_per_iter=1000) - torch.save((cluster_ids_x, cluster_centers), f'../{basename}_k_means_centroids.pth') - - -def use_kmeans(basename): - output_path = f'../results/{basename}_kmeans_viz' - _, centers = torch.load(f'../{basename}_k_means_centroids.pth') - centers = centers.to('cuda') - batch_size = 8 - num_workers = 0 - dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256) - colormap = cm.get_cmap('viridis', 8) - os.makedirs(output_path, exist_ok=True) - for i, batch in enumerate(tqdm(dataloader)): - hq = batch['hq'].to('cuda') - l = model(hq) - b, c, h, w = l.shape - dim = b*h*w - l = l.permute(0,2,3,1).reshape(dim,c) - pred = kmeans_predict(l, centers) - pred = pred.reshape(b,h,w) - img = torch.tensor(colormap(pred[:, :, :].detach().cpu().numpy())) - scale = hq.shape[-2] / h - torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=scale, mode="nearest"), - f"{output_path}/{i}_categories.png") - torchvision.utils.save_image(hq, f"{output_path}/{i}_hq.png") - - -if __name__ == '__main__': - pretrained_path = '../experiments/uresnet_pixpro4_imgset.pth' - basename = 'uresnet_pixpro4' - model = UResNet50_3(Bottleneck, [3,4,6,3], out_dim=64).to('cuda') - sd = torch.load(pretrained_path) - resnet_sd = {} - for k, v in sd.items(): - if 'target_encoder.net.' 
in k: - resnet_sd[k.replace('target_encoder.net.', '')] = v - model.load_state_dict(resnet_sd, strict=True) - model.eval() - - with torch.no_grad(): - #find_similar_latents(model, 0, 8, structural_euc_dist) - #create_latent_database(model, batch_size=32) - #produce_latent_dict(model, basename) - #uild_kmeans(basename) - use_kmeans(basename) diff --git a/codes/trainer/loss.py b/codes/trainer/loss.py index 43f6ae7d..159389fa 100644 --- a/codes/trainer/loss.py +++ b/codes/trainer/loss.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn import numpy as np -from utils.fdpl_util import extract_patches_2d, dct_2d from utils.colors import rgb2ycbcr @@ -56,99 +55,3 @@ class GANLoss(nn.Module): target_label = self.get_target_label(input, target_is_real) loss = self.loss(input.float(), target_label.float()) return loss - - -# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL -# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see scripts/compute_fdpl_perceptual_weights.py -# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs). -class FDPLLoss(nn.Module): - """ - Loss function taking the MSE between the 2D DCT coefficients - of predicted and target images or an image channel. - DCT coefficients are computed for each 8x8 block of the image - - Important note about this loss: Since it operates in the frequency domain, precision is highly important. - It works on FP64 numbers and will fail if you attempt to use it with AMP. Recommend you split this loss - off from the rest of amp.scale_loss(). - """ - - def __init__(self, dataset_diff_means_file, device): - """ - dataset_diff_means (torch.tensor): Pre-computed frequency-domain mean differences between LR and HR images. - """ - # These values are derived from the JPEG standard. - qt_Y = torch.tensor([[16, 11, 10, 16, 24, 40, 51, 61], - [12, 12, 14, 19, 26, 58, 60, 55], - [14, 13, 16, 24, 40, 57, 69, 56], - [14, 17, 22, 29, 51, 87, 80, 62], - [18, 22, 37, 56, 68, 109, 103, 77], - [24, 35, 55, 64, 81, 104, 113, 92], - [49, 64, 78, 87, 103, 121, 120, 101], - [72, 92, 95, 98, 112, 100, 103, 99]], - dtype=torch.double, - device=device, - requires_grad=False) - qt_C = torch.tensor([[17, 18, 24, 47, 99, 99, 99, 99], - [18, 21, 26, 66, 99, 99, 99, 99], - [24, 26, 56, 99, 99, 99, 99, 99], - [47, 66, 99, 99, 99, 99, 99, 99], - [99, 99, 99, 99, 99, 99, 99, 99], - [99, 99, 99, 99, 99, 99, 99, 99], - [99, 99, 99, 99, 99, 99, 99, 99], - [99, 99, 99, 99, 99, 99, 99, 99]], - dtype=torch.double, - device=device, - requires_grad=False) - """ - Reasoning behind this perceptual weight matrix: JPEG gives as a model of frequencies that are important - for human perception. In that model, lower frequencies are more important than higher frequencies. Because - of this, the higher frequencies are the first to go during compression. As compression increases, the affect - spreads to the lower frequencies, which degrades perceptual quality. But when the lower frequencies are - preserved, JPEG does an excellent job of compression without a noticeable loss of quality. - In super resolution, we already have the low frequencies. In fact that is really all we have in the low - resolution images. - As evidenced by the diff_means matrix above, what is lost in the SR images is the mid-range frequencies; - those across and towards the centre of the diagonal. 
We can bias our model to recover these frequencies - by having our loss function prioritize these coefficients, with priority determined by the magnitude of - relative change between the low-res and high-res images. But we can take this further and into a true - preceptual loss by further prioritizing DCT coefficients by the importance that has been assigned to them - by the JPEG quantization table. That is how the table below is created. - - The problem is that we don't know if the JPEG model is optimal. So there is room for qualitative evaluation - of the quantization table values. We can further our perspective weights deleting each in turn for a small - set of images and evaluating the resulting change in percieved quality. I can do this on my own to start and - if it works, I can do a small user study to determine this. - """ - diff_means = torch.tensor(torch.load(dataset_diff_means_file), device=device) - perceptual_weights = torch.stack([(torch.ones_like(qt_Y, device=device) / qt_Y), - (torch.ones_like(qt_C, device=device) / qt_C), - (torch.ones_like(qt_C, device=device) / qt_C)]) - perceptual_weights = perceptual_weights * diff_means - self.perceptual_weights = perceptual_weights / torch.mean(perceptual_weights) - super(FDPLLoss, self).__init__() - - def forward(self, predictions, targets): - """ - Args: - predictions (torch.tensor): output of an image transformation model. - shape: batch_size x 3 x H x W - targets (torch.tensor): ground truth images corresponding to outputs - shape: batch_size x 3 x H x W - criterion (torch.nn.MSELoss): object used to calculate MSE - - Returns: - loss (float): MSE between predicted and ground truth 2D DCT coefficients - """ - # transition to fp64 and then convert to YCC color space. - predictions = rgb2ycbcr(predictions.double()) - targets = rgb2ycbcr(targets.double()) - - # get DCT coefficients of ground truth patches - patches = extract_patches_2d(img=targets, patch_shape=(8, 8), batch_first=True) - ground_truth_dct = dct_2d(patches, norm='ortho') - - # get DCT coefficients of transformed images - patches = extract_patches_2d(img=predictions, patch_shape=(8, 8), batch_first=True) - outputs_dct = dct_2d(patches, norm='ortho') - loss = torch.sum(((outputs_dct - ground_truth_dct).pow(2)) * self.perceptual_weights) - return loss \ No newline at end of file
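
Note (not part of the patch): the removed FDPLLoss boils down to an MSE over per-block 2D DCT coefficients, weighted per-coefficient by a perceptual matrix. Below is a minimal sketch of that idea using only core PyTorch (F.unfold plus a hand-built orthonormal DCT-II matrix) in place of the repo's utils.fdpl_util helpers, the rgb2ycbcr conversion, and the precomputed dataset diff-means file; the inverse JPEG luma quantization table stands in for the full perceptual weights, and every name here is illustrative rather than part of this repository.

    import math
    import torch
    import torch.nn.functional as F


    def dct_matrix(n=8, dtype=torch.double):
        # Rows are the orthonormal DCT-II basis vectors.
        k = torch.arange(n, dtype=dtype).unsqueeze(1)   # frequency index
        i = torch.arange(n, dtype=dtype).unsqueeze(0)   # spatial index
        d = torch.cos((2 * i + 1) * k * math.pi / (2 * n)) * math.sqrt(2.0 / n)
        d[0] /= math.sqrt(2.0)
        return d


    def block_dct_loss(pred, target, weights, block=8):
        # Sum of squared differences between per-block 2D DCT coefficients,
        # weighted per-coefficient by an 8x8 `weights` matrix.
        d = dct_matrix(block, dtype=pred.dtype).to(pred.device)
        weights = weights.to(device=pred.device, dtype=pred.dtype)

        def blocks(x):
            # Non-overlapping 8x8 patches; one (block, block) matrix per block and channel.
            p = F.unfold(x, kernel_size=block, stride=block)     # (b, c*block*block, L)
            return p.transpose(1, 2).reshape(-1, block, block)

        def dct2(p):
            return d @ p @ d.t()                                 # 2D DCT of each block

        diff = dct2(blocks(pred)) - dct2(blocks(target))
        return torch.sum(diff.pow(2) * weights)


    # Hypothetical usage. The real loss additionally folded dataset-specific LR/HR
    # DCT mean differences into the weights and operated on YCbCr channels in fp64.
    qt_y = torch.tensor([[16, 11, 10, 16, 24, 40, 51, 61],
                         [12, 12, 14, 19, 26, 58, 60, 55],
                         [14, 13, 16, 24, 40, 57, 69, 56],
                         [14, 17, 22, 29, 51, 87, 80, 62],
                         [18, 22, 37, 56, 68, 109, 103, 77],
                         [24, 35, 55, 64, 81, 104, 113, 92],
                         [49, 64, 78, 87, 103, 121, 120, 101],
                         [72, 92, 95, 98, 112, 100, 103, 99]], dtype=torch.double)
    weights = 1.0 / qt_y
    weights = weights / weights.mean()

    pred = torch.rand(2, 1, 64, 64, dtype=torch.double)
    target = torch.rand(2, 1, 64, 64, dtype=torch.double)
    print(block_dct_loss(pred, target, weights))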
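
Also for reference: both deleted files (kmeans_mask_producer.py and byol_uresnet_playground.py) load a BYOL-trained UResNet checkpoint by stripping the 'target_encoder.net.' prefix from the state dict before calling load_state_dict. A small sketch of that pattern, with `model` and `checkpoint_path` as illustrative placeholders:

    import torch

    def load_byol_encoder(model, checkpoint_path, prefix='target_encoder.net.'):
        sd = torch.load(checkpoint_path, map_location='cpu')
        # Keep only the target-encoder weights and drop the BYOL wrapper prefix.
        stripped = {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
        model.load_state_dict(stripped, strict=True)
        model.eval()
        return model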