More cleanup

This commit is contained in:
James Betker 2021-09-29 14:24:49 -06:00
parent 6e550edfe3
commit 4914c526dc
4 changed files with 0 additions and 387 deletions

View File

@ -9,7 +9,6 @@ import torchvision
from torchvision.models.resnet import Bottleneck
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
from trainer.networks import register_model
from utils.util import checkpoint, sequential_checkpoint, opt_get
from models.switched_conv.switched_conv import SwitchedConv
@ -305,94 +304,6 @@ class RRDBNet(nn.Module):
if hasattr(bm, 'bypass_map'):
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
class RRDBNetSwitchedConv(nn.Module):
    """RRDB-style network whose post-body convolutions are ``SwitchedConv`` layers.

    A plain ``nn.Conv2d`` stem feeds a stack of ``body_block`` modules. Every
    convolution after the body (``conv_body``, ``conv_up1..3``, ``conv_hr``,
    ``conv_last``) is a ``SwitchedConv`` that receives, as a second input, the
    output of an internal ``UResNet50_3`` encoder run on a 2x bilinear upsample
    of the network input.

    NOTE(review): the source this block was recovered from had lost its
    indentation; the nesting below (in particular the ``else`` of the upsample
    chain and the weight-init loop living outside ``if resnet_encoder_dict``)
    follows the conventional reading - confirm against the original file.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=23,
                 growth_channels=32,
                 body_block=RRDB,
                 blocks_per_checkpoint=1,
                 scale=4,
                 initial_stride=1,
                 use_ref=False,  # When set, a reference image is expected as input and synthesized if not found. Useful for video SR.
                 resnet_encoder_dict=None
                 ):
        """Build the stem, body, switched convolutions and the switch encoder.

        Args:
            in_channels: Channel count of the input image.
            out_channels: Channel count produced by ``conv_last``.
            mid_channels: Width of the stem output and of every switched conv.
            num_blocks: Number of ``body_block`` instances passed to ``make_layer``.
            growth_channels: Forwarded to each body block as ``growth_channels``.
            body_block: Callable/class used to build each body block.
            blocks_per_checkpoint: Divides ``num_blocks`` to obtain the segment
                count handed to ``sequential_checkpoint`` in ``forward``.
            scale: Upsampling factor; selects which upsample stages run and
                whether ``conv_up3`` is created (only when ``scale >= 8``).
            initial_stride: Stride of the stem conv when ``use_ref`` is False.
            use_ref: When True the stem conv stride is ``scale`` instead of
                ``initial_stride``. ``forward`` does not otherwise use it.
            resnet_encoder_dict: Optional path given to ``torch.load`` whose
                result is loaded into the switch encoder's state dict.
        """
        super().__init__()
        self.num_blocks = num_blocks
        self.blocks_per_checkpoint = blocks_per_checkpoint
        self.scale = scale
        self.in_channels = in_channels
        self.use_ref = use_ref
        # A strided stem uses a larger 7x7 kernel (padding 3); an unstrided one uses 3x3 (padding 1).
        first_conv_stride = initial_stride if not self.use_ref else scale
        first_conv_ksize = 3 if first_conv_stride == 1 else 7
        first_conv_padding = 1 if first_conv_stride == 1 else 3
        self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding)
        self.reduce_ch = mid_channels
        # No channel reduction is requested from the body blocks in this variant.
        reduce_to = None
        self.body = make_layer(
            body_block,
            num_blocks,
            mid_channels=mid_channels,
            growth_channels=growth_channels,
            reduce_to=reduce_to)
        # NOTE(review): the positional arguments (3, 8, 1, 1) are presumably kernel size,
        # switch breadth, stride and padding - verify against SwitchedConv's signature.
        # coupler_dim_in=64 matches the out_dim=64 of the switch encoder built below.
        self.conv_body = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
        # upsample
        self.conv_up1 = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
        self.conv_up2 = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
        if scale >= 8:
            self.conv_up3 = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
        else:
            self.conv_up3 = None
        self.conv_hr = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
        self.conv_last = SwitchedConv(self.reduce_ch, out_channels, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Encoder whose output is passed to every SwitchedConv in forward().
        self.resnet_encoder = UResNet50_3(Bottleneck, [3, 4, 6, 3], out_dim=64)
        if resnet_encoder_dict:
            self.resnet_encoder.load_state_dict(torch.load(resnet_encoder_dict))

        # Apply the shared initializer (scale 0.1) to every conv that exists;
        # conv_up3 is None unless scale >= 8.
        for m in [
            self.conv_first, self.conv_body, self.conv_up1,
            self.conv_up2, self.conv_up3, self.conv_hr, self.conv_last
        ]:
            if m is not None:
                default_init_weights(m, 0.1)

    def forward(self, x, ref=None):
        """Run the network on ``x`` and return the output of ``conv_last``.

        Args:
            x: Input batch; it is fed to ``conv_first`` and, after a 2x bilinear
                upsample, to the switch encoder.
            ref: Accepted for interface compatibility; not used in this method.
        """
        # Encoder features computed once and shared by every switched conv below.
        switch_enc = checkpoint(self.resnet_encoder, F.interpolate(x, scale_factor=2, mode="bilinear"))
        x_lg = x
        feat = self.conv_first(x_lg)
        feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
        # Keep only the first reduce_ch channels of the body output.
        feat = feat[:, :self.reduce_ch]
        body_feat = checkpoint(self.conv_body, feat, switch_enc)
        feat = feat + body_feat

        # upsample
        out = self.lrelu(
            checkpoint(self.conv_up1, F.interpolate(feat, scale_factor=2, mode='nearest'), switch_enc))
        if self.scale >= 4:
            out = self.lrelu(
                checkpoint(self.conv_up2, F.interpolate(out, scale_factor=2, mode='nearest'), switch_enc))
            if self.scale >= 8:
                # Unlike the other stages, this call is not wrapped in checkpoint().
                out = self.lrelu(
                    self.conv_up3(F.interpolate(out, scale_factor=2, mode='nearest'), switch_enc))
        else:
            # For scale < 4, conv_up2 is still applied but without a further upsample.
            out = self.lrelu(checkpoint(self.conv_up2, out, switch_enc))
        out = checkpoint(self.conv_hr, out, switch_enc)
        out = checkpoint(self.conv_last, self.lrelu(out), switch_enc)
        return out

    def visual_dbg(self, step, path):
        """Save each body block's ``bypass_map`` (when present) as a PNG under ``path``."""
        for i, bm in enumerate(self.body):
            if hasattr(bm, 'bypass_map'):
                torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
@register_model
def register_RRDBNetBypass(opt_net, opt):
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
@ -418,14 +329,3 @@ def register_RRDBNet(opt_net, opt):
output_mode=output_mode, body_block=RRDB, scale=opt_net['scale'], growth_channels=gc,
initial_stride=initial_stride)
@register_model
def register_rrdb_switched_conv(opt_net, opt):
    """Build an ``RRDBNetSwitchedConv`` from a network options mapping.

    Reads ``in_nc``, ``out_nc``, ``nf``, ``nb``, ``scale`` and
    ``switch_encoder`` from ``opt_net``; ``gc`` (default 32),
    ``initial_stride`` (default 1) and ``bypass_noise`` (default False) are
    optional. ``opt`` is not used.
    """
    if 'gc' in opt_net.keys():
        growth = opt_net['gc']
    else:
        growth = 32

    if 'initial_stride' in opt_net.keys():
        stride = opt_net['initial_stride']
    else:
        stride = 1

    # Body blocks are RRDBWithBypass with the noise flag pre-bound.
    noisy_bypass = opt_get(opt_net, ['bypass_noise'], False)
    body = functools.partial(RRDBWithBypass, randomly_add_noise_to_bypass=noisy_bypass)

    return RRDBNetSwitchedConv(
        in_channels=opt_net['in_nc'],
        out_channels=opt_net['out_nc'],
        mid_channels=opt_net['nf'],
        num_blocks=opt_net['nb'],
        body_block=body,
        scale=opt_net['scale'],
        growth_channels=growth,
        initial_stride=stride,
        resnet_encoder_dict=opt_net['switch_encoder'],
    )

View File

@ -1,48 +0,0 @@
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models.resnet import Bottleneck
from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
from trainer.networks import register_model
from utils.kmeans import kmeans_predict
from utils.util import opt_get
class UResnetMaskProducer(nn.Module):
    """Produce k-means cluster-id masks from the latents of a pretrained UResNet50.

    The module runs the input through a UResNet50, assigns every spatial
    location of the resulting latent map to a k-means centroid, and returns the
    cluster-id map resized to several scales of the input resolution.
    """

    def __init__(self, pretrained_uresnet_path, kmeans_centroid_path, mask_scales=(.125, .25, .5, 1), tail_dim=512):
        """Load the centroids and the pretrained encoder weights.

        Args:
            pretrained_uresnet_path: Path passed to ``torch.load``; the result
                is iterated as a state dict.
            kmeans_centroid_path: Path passed to ``torch.load``; the result is
                unpacked as a 2-tuple whose second element is the centroids.
            mask_scales: Scale factors, relative to the input height/width, at
                which masks are produced. Any iterable of numbers is accepted
                (the default is a tuple so it is not shared mutable state).
            tail_dim: ``out_dim`` of the UResNet50.
        """
        super().__init__()
        _, centroids = torch.load(kmeans_centroid_path)
        self.centroids = nn.Parameter(centroids)
        # NOTE(review): the encoder is placed on 'cuda' unconditionally.
        self.ures = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=tail_dim).to('cuda')
        self.mask_scales = mask_scales

        sd = torch.load(pretrained_uresnet_path)
        # An assumption is made that the state_dict came from a byol model. Strip out unnecessary weights..
        prefix = 'target_encoder.net.'
        resnet_sd = {k.replace(prefix, ''): v for k, v in sd.items() if prefix in k}

        self.ures.load_state_dict(resnet_sd, strict=True)
        self.ures.eval()

    def forward(self, x):
        """Return a dict mapping mask width (int) to a long tensor of cluster ids.

        Each value has shape ``(b, 1, int(sf * H), int(sf * W))`` for a scale
        factor ``sf`` in ``mask_scales``, where ``H, W`` are the last two dims
        of ``x``. The dict is keyed by the mask *width*, so two scales that
        yield the same width overwrite each other.
        """
        with torch.no_grad():
            latents = self.ures(x)
            b, c, h, w = latents.shape
            # Flatten to one row per spatial location so each can be assigned a centroid.
            latents = latents.permute(0, 2, 3, 1).reshape(b * h * w, c)
            masks = kmeans_predict(latents, self.centroids).float()
            masks = masks.reshape(b, 1, h, w)
            interpolated_masks = {}
            for sf in self.mask_scales:
                dim_h, dim_w = int(sf * x.shape[-2]), int(sf * x.shape[-1])
                # Nearest-neighbour keeps the values valid cluster ids.
                imask = F.interpolate(masks, size=(dim_h, dim_w), mode="nearest")
                interpolated_masks[dim_w] = imask.long()
            return interpolated_masks
@register_model
def register_uresnet_mask_producer(opt_net, opt):
    """Construct a ``UResnetMaskProducer`` from the ``kwargs`` entry of ``opt_net``.

    Falls back to no keyword arguments when ``kwargs`` is absent. ``opt`` is
    not used.
    """
    return UResnetMaskProducer(**opt_get(opt_net, ['kwargs'], {}))

View File

@ -1,142 +0,0 @@
import os
from random import shuffle
import matplotlib.cm as cm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision.models.resnet import Bottleneck
from tqdm import tqdm
from data.image_folder_dataset import ImageFolderDataset
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
from utils.kmeans import kmeans, kmeans_predict
from utils.options import dict_to_nonedict
def structural_euc_dist(x, y):
    """Return the Euclidean distance between ``x`` and ``y`` along the last dim.

    All leading dimensions are preserved; only the final dimension is reduced.
    """
    squared_diff = (x - y).pow(2)
    return squared_diff.sum(dim=-1).sqrt()
def cosine_similarity(x, y):
    """Return the negated cosine similarity of the standardized inputs.

    Both inputs are first passed through ``norm`` (standardization over the
    last dim) and then compared with a default-constructed
    ``nn.CosineSimilarity``.

    NOTE(review): ``nn.CosineSimilarity()`` reduces over dim=1 by default while
    ``norm`` works on the last dim - confirm this is intended for inputs with
    more than two dimensions.
    """
    # Left as a reminder that the nn.CosineSimilarity class does the calculation.
    similarity = nn.CosineSimilarity()
    return -similarity(norm(x), norm(y))
def key_value_difference(x, y):
    """Return ``2 - 2 * <x_hat, y_hat>`` where the hats are L2-normalized over the last dim.

    The result is 0 for parallel vectors, 2 for orthogonal ones and 4 for
    opposite ones; the last dimension is reduced.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    dot = torch.sum(x_unit * y_unit, dim=-1)
    return 2 - 2 * dot
def norm(x):
    """Standardize ``x`` along its last dimension.

    Subtracts the mean and divides by the standard deviation (``torch.std``
    defaults), both taken over the last dim; the shape is unchanged.
    """
    center = x.mean(dim=-1, keepdim=True)
    spread = x.std(dim=-1, keepdim=True)
    return (x - center) / spread
def im_norm(x):
    """Standardize each channel of each image, then map to mean 0.5 / std 0.5.

    Args:
        x: 4-D tensor; statistics are taken over dims 2 and 3 separately for
            every (batch, channel) pair.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # keepdim=True keeps the statistics shaped (B, C, 1, 1) so they broadcast
    # per image and per channel. Flattening them to (B*C, 1, 1, 1) fails to
    # broadcast for B > 1 and silently mixes channels when B == 1.
    mean = x.mean(dim=(2, 3), keepdim=True)
    std = x.std(dim=(2, 3), keepdim=True)
    return ((x - mean) / std) * .5 + .5
def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
    """Return a shuffling ``DataLoader`` over a hard-coded ``ImageFolderDataset``.

    Args:
        batch_size: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        target_size: Stored as the dataset option ``target_size``.
    """
    # Alternative dataset roots that have been used with this script:
    #   'F:\\4k6k\\datasets\\images\\imagenet_2017\\train'
    #   'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
    #   'F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'
    image_roots = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full']
    options = dict(
        name='amalgam',
        paths=image_roots,
        weights=[1],
        target_size=target_size,
        force_multiple=32,
        scale=1,
    )
    image_dataset = ImageFolderDataset(dict_to_nonedict(options))
    return DataLoader(image_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
def produce_latent_dict(model, basename):
    """Sample latents produced by ``model`` over the image dataset and save them.

    For every batch, the model output is flattened to one latent per spatial
    location and up to 100 of those latents are drawn uniformly at random
    (without replacement) from the whole batch. The accumulated
    ``(latents, paths)`` tuple is written to ``../{basename}_latent_dict.pth``
    periodically and once more when the loader is exhausted.

    Args:
        model: Callable applied to ``batch['hq']`` (moved to 'cuda'); its
            output is unpacked as a 4-D ``(b, c, h, w)`` tensor.
        basename: Prefix of the output file name.

    NOTE(review): ``paths`` gets one entry per image while ``latents`` gets up
    to 100 entries per batch, so the two lists are not index-aligned.
    """
    batch_size = 64
    num_workers = 4
    samples_per_batch = 100
    save_path = f'../{basename}_latent_dict.pth'
    dataloader = get_image_folder_dataloader(batch_size, num_workers)
    seen_since_save = 0
    paths = []
    latents = []
    prob = None
    for batch in tqdm(dataloader):
        hq = batch['hq'].to('cuda')
        l = model(hq)
        b, c, h, w = l.shape
        dim = b * h * w
        l = l.permute(0, 2, 3, 1).reshape(dim, c).cpu()

        # Draw a random subset of latents from this batch. The uniform
        # distribution is rebuilt whenever the number of candidates changes
        # (e.g. a smaller final batch); reusing a stale one would produce
        # indices past the end of `l`.
        if prob is None or prob.shape[0] != dim:
            prob = torch.full((dim,), 1 / dim)
        picks = prob.multinomial(num_samples=min(samples_per_batch, dim), replacement=False)
        latents.extend(l[picks].split(1, dim=0))
        paths.extend(batch['HQ_path'])

        seen_since_save += batch_size
        if seen_since_save > 5000:
            print("Saving checkpoint..")
            torch.save((latents, paths), save_path)
            seen_since_save = 0

    # Persist whatever was gathered after the last periodic checkpoint.
    if seen_since_save > 0:
        print("Saving checkpoint..")
        torch.save((latents, paths), save_path)
def build_kmeans(basename):
    """Cluster the saved latents into 8 groups and save ids and centroids.

    Reads ``../{basename}_latent_dict.pth`` and writes the
    ``(cluster_ids_x, cluster_centers)`` tuple returned by ``kmeans`` to
    ``../{basename}_k_means_centroids.pth``.
    """
    latent_chunks, _ = torch.load(f'../{basename}_latent_dict.pth')
    shuffle(latent_chunks)  # in-place shuffle of the list of latent chunks
    stacked = torch.cat(latent_chunks, dim=0).to('cuda')
    cluster_ids_x, cluster_centers = kmeans(
        stacked,
        num_clusters=8,
        distance="euclidean",
        device=torch.device('cuda:0'),
        tol=0,
        iter_limit=5000,
        gravity_limit_per_iter=1000,
    )
    torch.save((cluster_ids_x, cluster_centers), f'../{basename}_k_means_centroids.pth')
def use_kmeans(basename, model=None):
    """Visualize k-means cluster assignments of model latents over the dataset.

    For every batch, writes ``{i}_categories.png`` (cluster ids colored with
    the 8-entry 'viridis' colormap and upsampled to the input height scale) and
    ``{i}_hq.png`` (the input batch) under ``../results/{basename}_kmeans_viz``.

    Args:
        basename: Prefix used to locate ``../{basename}_k_means_centroids.pth``
            and to name the output directory.
        model: Callable producing a ``(b, c, h, w)`` latent tensor from an
            image batch. When omitted, the module-level ``model`` is used, as
            set up by this file's ``__main__`` section.

    Raises:
        NameError: If ``model`` is not given and no module-level ``model``
            exists.
    """
    if model is None:
        # Previously this function silently depended on a global that only
        # exists when the file is run as a script; keep that as a fallback.
        try:
            model = globals()['model']
        except KeyError:
            raise NameError("name 'model' is not defined") from None

    output_path = f'../results/{basename}_kmeans_viz'
    _, centers = torch.load(f'../{basename}_k_means_centroids.pth')
    centers = centers.to('cuda')
    batch_size = 8
    num_workers = 0
    dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256)
    colormap = cm.get_cmap('viridis', 8)
    os.makedirs(output_path, exist_ok=True)
    for i, batch in enumerate(tqdm(dataloader)):
        hq = batch['hq'].to('cuda')
        l = model(hq)
        b, c, h, w = l.shape
        dim = b * h * w
        # One row per spatial location so each can be assigned to a centroid.
        l = l.permute(0, 2, 3, 1).reshape(dim, c)
        pred = kmeans_predict(l, centers)
        pred = pred.reshape(b, h, w)
        # The colormap maps each cluster id to RGBA, giving a (b, h, w, 4) array.
        img = torch.tensor(colormap(pred.detach().cpu().numpy()))
        scale = hq.shape[-2] / h
        torchvision.utils.save_image(
            torch.nn.functional.interpolate(img.permute(0, 3, 1, 2), scale_factor=scale, mode="nearest"),
            f"{output_path}/{i}_categories.png")
        torchvision.utils.save_image(hq, f"{output_path}/{i}_hq.png")
if __name__ == '__main__':
    # Script entry point: load a pretrained UResNet50_3 and run one stage of
    # the latent -> k-means -> visualization pipeline.
    pretrained_path = '../experiments/uresnet_pixpro4_imgset.pth'
    basename = 'uresnet_pixpro4'
    model = UResNet50_3(Bottleneck, [3, 4, 6, 3], out_dim=64).to('cuda')

    # Keep only the target-encoder weights from the checkpoint, with the
    # wrapper prefix removed from their keys.
    sd = torch.load(pretrained_path)
    resnet_sd = {
        k.replace('target_encoder.net.', ''): v
        for k, v in sd.items()
        if 'target_encoder.net.' in k
    }
    model.load_state_dict(resnet_sd, strict=True)
    model.eval()

    with torch.no_grad():
        # Uncomment the stage to run.
        #find_similar_latents(model, 0, 8, structural_euc_dist)
        #create_latent_database(model, batch_size=32)
        #produce_latent_dict(model, basename)
        #build_kmeans(basename)
        use_kmeans(basename)

View File

@ -1,7 +1,6 @@
import torch
import torch.nn as nn
import numpy as np
from utils.fdpl_util import extract_patches_2d, dct_2d
from utils.colors import rgb2ycbcr
@ -56,99 +55,3 @@ class GANLoss(nn.Module):
target_label = self.get_target_label(input, target_is_real)
loss = self.loss(input.float(), target_label.float())
return loss
# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL
# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see scripts/compute_fdpl_perceptual_weights.py
# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs).
class FDPLLoss(nn.Module):
    """
    Loss function taking the weighted squared error between the 2D DCT
    coefficients of predicted and target images or an image channel.
    DCT coefficients are computed for each 8x8 block of the image.

    Important note about this loss: Since it operates in the frequency domain, precision is highly important.
    It works on FP64 numbers and will fail if you attempt to use it with AMP. Recommend you split this loss
    off from the rest of amp.scale_loss().
    """

    def __init__(self, dataset_diff_means_file, device):
        """
        Args:
            dataset_diff_means_file: Path passed to ``torch.load``; it holds the pre-computed frequency-domain
                mean differences between LR and HR images. It must broadcast against the (3, 8, 8) stack of
                inverse quantization tables built below.
            device: Device on which the quantization tables and weights are created.
        """
        # nn.Module must be initialized before any attribute is assigned on it.
        super(FDPLLoss, self).__init__()

        # These values are derived from the JPEG standard.
        qt_Y = torch.tensor([[16, 11, 10, 16, 24, 40, 51, 61],
                             [12, 12, 14, 19, 26, 58, 60, 55],
                             [14, 13, 16, 24, 40, 57, 69, 56],
                             [14, 17, 22, 29, 51, 87, 80, 62],
                             [18, 22, 37, 56, 68, 109, 103, 77],
                             [24, 35, 55, 64, 81, 104, 113, 92],
                             [49, 64, 78, 87, 103, 121, 120, 101],
                             [72, 92, 95, 98, 112, 100, 103, 99]],
                            dtype=torch.double,
                            device=device,
                            requires_grad=False)

        qt_C = torch.tensor([[17, 18, 24, 47, 99, 99, 99, 99],
                             [18, 21, 26, 66, 99, 99, 99, 99],
                             [24, 26, 56, 99, 99, 99, 99, 99],
                             [47, 66, 99, 99, 99, 99, 99, 99],
                             [99, 99, 99, 99, 99, 99, 99, 99],
                             [99, 99, 99, 99, 99, 99, 99, 99],
                             [99, 99, 99, 99, 99, 99, 99, 99],
                             [99, 99, 99, 99, 99, 99, 99, 99]],
                            dtype=torch.double,
                            device=device,
                            requires_grad=False)

        # Reasoning behind this perceptual weight matrix: JPEG gives us a model of frequencies that are
        # important for human perception. In that model, lower frequencies are more important than higher
        # frequencies. Because of this, the higher frequencies are the first to go during compression. As
        # compression increases, the effect spreads to the lower frequencies, which degrades perceptual
        # quality. But when the lower frequencies are preserved, JPEG does an excellent job of compression
        # without a noticeable loss of quality.
        #
        # In super resolution, we already have the low frequencies. In fact that is really all we have in the
        # low resolution images.
        #
        # As evidenced by the diff_means matrix, what is lost in the SR images is the mid-range frequencies;
        # those across and towards the centre of the diagonal. We can bias our model to recover these
        # frequencies by having our loss function prioritize these coefficients, with priority determined by
        # the magnitude of relative change between the low-res and high-res images. But we can take this
        # further and into a true perceptual loss by further prioritizing DCT coefficients by the importance
        # that has been assigned to them by the JPEG quantization table. That is how the table below is
        # created.
        #
        # The problem is that we don't know if the JPEG model is optimal. So there is room for qualitative
        # evaluation of the quantization table values. We can further refine our perceptual weights by
        # deleting each in turn for a small set of images and evaluating the resulting change in perceived
        # quality. I can do this on my own to start and if it works, I can do a small user study to
        # determine this.
        diff_means = torch.tensor(torch.load(dataset_diff_means_file), device=device)
        # One inverse quantization table per YCbCr channel: luma, then chroma twice.
        perceptual_weights = torch.stack([(torch.ones_like(qt_Y, device=device) / qt_Y),
                                          (torch.ones_like(qt_C, device=device) / qt_C),
                                          (torch.ones_like(qt_C, device=device) / qt_C)])
        perceptual_weights = perceptual_weights * diff_means
        # Normalize so the weights average to 1.
        self.perceptual_weights = perceptual_weights / torch.mean(perceptual_weights)

    def forward(self, predictions, targets):
        """
        Args:
            predictions (torch.tensor): output of an image transformation model.
                shape: batch_size x 3 x H x W
            targets (torch.tensor): ground truth images corresponding to outputs
                shape: batch_size x 3 x H x W
        Returns:
            loss (torch.tensor): sum over all 8x8 patches of the perceptually weighted squared difference
                between predicted and ground truth 2D DCT coefficients
        """
        # transition to fp64 and then convert to YCC color space.
        predictions = rgb2ycbcr(predictions.double())
        targets = rgb2ycbcr(targets.double())

        # get DCT coefficients of ground truth patches
        patches = extract_patches_2d(img=targets, patch_shape=(8, 8), batch_first=True)
        ground_truth_dct = dct_2d(patches, norm='ortho')

        # get DCT coefficients of transformed images
        patches = extract_patches_2d(img=predictions, patch_shape=(8, 8), batch_first=True)
        outputs_dct = dct_2d(patches, norm='ortho')
        loss = torch.sum(((outputs_dct - ground_truth_dct).pow(2)) * self.perceptual_weights)
        return loss