More cleanup
This commit is contained in:
parent
6e550edfe3
commit
4914c526dc
|
@ -9,7 +9,6 @@ import torchvision
|
|||
from torchvision.models.resnet import Bottleneck
|
||||
|
||||
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, sequential_checkpoint, opt_get
|
||||
from models.switched_conv.switched_conv import SwitchedConv
|
||||
|
@ -305,94 +304,6 @@ class RRDBNet(nn.Module):
|
|||
if hasattr(bm, 'bypass_map'):
|
||||
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
||||
|
||||
|
||||
class RRDBNetSwitchedConv(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
mid_channels=64,
|
||||
num_blocks=23,
|
||||
growth_channels=32,
|
||||
body_block=RRDB,
|
||||
blocks_per_checkpoint=1,
|
||||
scale=4,
|
||||
initial_stride=1,
|
||||
use_ref=False, # When set, a reference image is expected as input and synthesized if not found. Useful for video SR.
|
||||
resnet_encoder_dict=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.blocks_per_checkpoint = blocks_per_checkpoint
|
||||
self.scale = scale
|
||||
self.in_channels = in_channels
|
||||
self.use_ref = use_ref
|
||||
first_conv_stride = initial_stride if not self.use_ref else scale
|
||||
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
||||
first_conv_padding = 1 if first_conv_stride == 1 else 3
|
||||
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding)
|
||||
self.reduce_ch = mid_channels
|
||||
reduce_to = None
|
||||
self.body = make_layer(
|
||||
body_block,
|
||||
num_blocks,
|
||||
mid_channels=mid_channels,
|
||||
growth_channels=growth_channels,
|
||||
reduce_to=reduce_to)
|
||||
self.conv_body = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
|
||||
# upsample
|
||||
self.conv_up1 = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
|
||||
self.conv_up2 = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
|
||||
if scale >= 8:
|
||||
self.conv_up3 = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
|
||||
else:
|
||||
self.conv_up3 = None
|
||||
self.conv_hr = SwitchedConv(self.reduce_ch, self.reduce_ch, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
|
||||
self.conv_last = SwitchedConv(self.reduce_ch, out_channels, 3, 8, 1, 1, include_coupler=True, coupler_dim_in=64)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
self.resnet_encoder = UResNet50_3(Bottleneck, [3, 4, 6, 3], out_dim=64)
|
||||
if resnet_encoder_dict:
|
||||
self.resnet_encoder.load_state_dict(torch.load(resnet_encoder_dict))
|
||||
|
||||
for m in [
|
||||
self.conv_first, self.conv_body, self.conv_up1,
|
||||
self.conv_up2, self.conv_up3, self.conv_hr, self.conv_last
|
||||
]:
|
||||
if m is not None:
|
||||
default_init_weights(m, 0.1)
|
||||
|
||||
def forward(self, x, ref=None):
|
||||
switch_enc = checkpoint(self.resnet_encoder, F.interpolate(x, scale_factor=2, mode="bilinear"))
|
||||
|
||||
x_lg = x
|
||||
feat = self.conv_first(x_lg)
|
||||
feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
|
||||
feat = feat[:, :self.reduce_ch]
|
||||
body_feat = checkpoint(self.conv_body, feat, switch_enc)
|
||||
feat = feat + body_feat
|
||||
|
||||
# upsample
|
||||
out = self.lrelu(
|
||||
checkpoint(self.conv_up1, F.interpolate(feat, scale_factor=2, mode='nearest'), switch_enc))
|
||||
if self.scale >= 4:
|
||||
out = self.lrelu(
|
||||
checkpoint(self.conv_up2, F.interpolate(out, scale_factor=2, mode='nearest'), switch_enc))
|
||||
if self.scale >= 8:
|
||||
out = self.lrelu(
|
||||
self.conv_up3(F.interpolate(out, scale_factor=2, mode='nearest'), switch_enc))
|
||||
else:
|
||||
out = self.lrelu(checkpoint(self.conv_up2, out, switch_enc))
|
||||
out = checkpoint(self.conv_hr, out, switch_enc)
|
||||
out = checkpoint(self.conv_last, self.lrelu(out), switch_enc)
|
||||
return out
|
||||
|
||||
def visual_dbg(self, step, path):
|
||||
for i, bm in enumerate(self.body):
|
||||
if hasattr(bm, 'bypass_map'):
|
||||
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
||||
|
||||
|
||||
@register_model
|
||||
def register_RRDBNetBypass(opt_net, opt):
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
||||
|
@ -418,14 +329,3 @@ def register_RRDBNet(opt_net, opt):
|
|||
output_mode=output_mode, body_block=RRDB, scale=opt_net['scale'], growth_channels=gc,
|
||||
initial_stride=initial_stride)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_rrdb_switched_conv(opt_net, opt):
|
||||
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
|
||||
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
|
||||
bypass_noise = opt_get(opt_net, ['bypass_noise'], False)
|
||||
block = functools.partial(RRDBWithBypass, randomly_add_noise_to_bypass=bypass_noise)
|
||||
return RRDBNetSwitchedConv(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
|
||||
body_block=block, scale=opt_net['scale'], growth_channels=gc,
|
||||
initial_stride=initial_stride, resnet_encoder_dict=opt_net['switch_encoder'])
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.models.resnet import Bottleneck
|
||||
|
||||
from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
|
||||
from trainer.networks import register_model
|
||||
from utils.kmeans import kmeans_predict
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class UResnetMaskProducer(nn.Module):
|
||||
def __init__(self, pretrained_uresnet_path, kmeans_centroid_path, mask_scales=[.125,.25,.5,1], tail_dim=512):
|
||||
super().__init__()
|
||||
_, centroids = torch.load(kmeans_centroid_path)
|
||||
self.centroids = nn.Parameter(centroids)
|
||||
self.ures = UResNet50(Bottleneck, [3,4,6,3], out_dim=tail_dim).to('cuda')
|
||||
self.mask_scales = mask_scales
|
||||
|
||||
sd = torch.load(pretrained_uresnet_path)
|
||||
# An assumption is made that the state_dict came from a byol model. Strip out unnecessary weights..
|
||||
resnet_sd = {}
|
||||
for k, v in sd.items():
|
||||
if 'target_encoder.net.' in k:
|
||||
resnet_sd[k.replace('target_encoder.net.', '')] = v
|
||||
|
||||
self.ures.load_state_dict(resnet_sd, strict=True)
|
||||
self.ures.eval()
|
||||
|
||||
def forward(self, x):
|
||||
with torch.no_grad():
|
||||
latents = self.ures(x)
|
||||
b,c,h,w = latents.shape
|
||||
latents = latents.permute(0,2,3,1).reshape(b*h*w,c)
|
||||
masks = kmeans_predict(latents, self.centroids).float()
|
||||
masks = masks.reshape(b,1,h,w)
|
||||
interpolated_masks = {}
|
||||
for sf in self.mask_scales:
|
||||
dim_h, dim_w = int(sf*x.shape[-2]), int(sf*x.shape[-1])
|
||||
imask = F.interpolate(masks, size=(dim_h,dim_w), mode="nearest")
|
||||
interpolated_masks[dim_w] = imask.long()
|
||||
return interpolated_masks
|
||||
|
||||
|
||||
@register_model
|
||||
def register_uresnet_mask_producer(opt_net, opt):
|
||||
kw = opt_get(opt_net, ['kwargs'], {})
|
||||
return UResnetMaskProducer(**kw)
|
|
@ -1,142 +0,0 @@
|
|||
import os
|
||||
from random import shuffle
|
||||
|
||||
import matplotlib.cm as cm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.models.resnet import Bottleneck
|
||||
from tqdm import tqdm
|
||||
|
||||
from data.image_folder_dataset import ImageFolderDataset
|
||||
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
|
||||
|
||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||
# and the distance is computed across the channel dimension.
|
||||
from utils.kmeans import kmeans, kmeans_predict
|
||||
from utils.options import dict_to_nonedict
|
||||
|
||||
|
||||
def structural_euc_dist(x, y):
|
||||
diff = torch.square(x - y)
|
||||
sum = torch.sum(diff, dim=-1)
|
||||
return torch.sqrt(sum)
|
||||
|
||||
|
||||
def cosine_similarity(x, y):
|
||||
x = norm(x)
|
||||
y = norm(y)
|
||||
return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
|
||||
|
||||
|
||||
def key_value_difference(x, y):
|
||||
x = F.normalize(x, dim=-1, p=2)
|
||||
y = F.normalize(y, dim=-1, p=2)
|
||||
return 2 - 2 * (x * y).sum(dim=-1)
|
||||
|
||||
|
||||
def norm(x):
|
||||
sh = x.shape
|
||||
sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
|
||||
return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r)
|
||||
|
||||
|
||||
def im_norm(x):
|
||||
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
|
||||
|
||||
|
||||
def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
|
||||
dataset_opt = dict_to_nonedict({
|
||||
'name': 'amalgam',
|
||||
#'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||
'weights': [1],
|
||||
'target_size': target_size,
|
||||
'force_multiple': 32,
|
||||
'scale': 1
|
||||
})
|
||||
dataset = ImageFolderDataset(dataset_opt)
|
||||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
||||
|
||||
|
||||
def produce_latent_dict(model, basename):
|
||||
batch_size = 64
|
||||
num_workers = 4
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||
id = 0
|
||||
paths = []
|
||||
latents = []
|
||||
prob = None
|
||||
for batch in tqdm(dataloader):
|
||||
hq = batch['hq'].to('cuda')
|
||||
l = model(hq)
|
||||
b, c, h, w = l.shape
|
||||
dim = b*h*w
|
||||
l = l.permute(0,2,3,1).reshape(dim, c).cpu()
|
||||
# extract a random set of 10 latents from each image
|
||||
if prob is None:
|
||||
prob = torch.full((dim,), 1/(dim))
|
||||
l = l[prob.multinomial(num_samples=100, replacement=False)].split(1, dim=0)
|
||||
latents.extend(l)
|
||||
paths.extend(batch['HQ_path'])
|
||||
id += batch_size
|
||||
if id > 5000:
|
||||
print("Saving checkpoint..")
|
||||
torch.save((latents, paths), f'../{basename}_latent_dict.pth')
|
||||
id = 0
|
||||
|
||||
|
||||
def build_kmeans(basename):
|
||||
latents, _ = torch.load(f'../{basename}_latent_dict.pth')
|
||||
shuffle(latents)
|
||||
latents = torch.cat(latents, dim=0).to('cuda')
|
||||
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'), tol=0, iter_limit=5000, gravity_limit_per_iter=1000)
|
||||
torch.save((cluster_ids_x, cluster_centers), f'../{basename}_k_means_centroids.pth')
|
||||
|
||||
|
||||
def use_kmeans(basename):
|
||||
output_path = f'../results/{basename}_kmeans_viz'
|
||||
_, centers = torch.load(f'../{basename}_k_means_centroids.pth')
|
||||
centers = centers.to('cuda')
|
||||
batch_size = 8
|
||||
num_workers = 0
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256)
|
||||
colormap = cm.get_cmap('viridis', 8)
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
for i, batch in enumerate(tqdm(dataloader)):
|
||||
hq = batch['hq'].to('cuda')
|
||||
l = model(hq)
|
||||
b, c, h, w = l.shape
|
||||
dim = b*h*w
|
||||
l = l.permute(0,2,3,1).reshape(dim,c)
|
||||
pred = kmeans_predict(l, centers)
|
||||
pred = pred.reshape(b,h,w)
|
||||
img = torch.tensor(colormap(pred[:, :, :].detach().cpu().numpy()))
|
||||
scale = hq.shape[-2] / h
|
||||
torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=scale, mode="nearest"),
|
||||
f"{output_path}/{i}_categories.png")
|
||||
torchvision.utils.save_image(hq, f"{output_path}/{i}_hq.png")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pretrained_path = '../experiments/uresnet_pixpro4_imgset.pth'
|
||||
basename = 'uresnet_pixpro4'
|
||||
model = UResNet50_3(Bottleneck, [3,4,6,3], out_dim=64).to('cuda')
|
||||
sd = torch.load(pretrained_path)
|
||||
resnet_sd = {}
|
||||
for k, v in sd.items():
|
||||
if 'target_encoder.net.' in k:
|
||||
resnet_sd[k.replace('target_encoder.net.', '')] = v
|
||||
model.load_state_dict(resnet_sd, strict=True)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
#find_similar_latents(model, 0, 8, structural_euc_dist)
|
||||
#create_latent_database(model, batch_size=32)
|
||||
#produce_latent_dict(model, basename)
|
||||
#uild_kmeans(basename)
|
||||
use_kmeans(basename)
|
|
@ -1,7 +1,6 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from utils.fdpl_util import extract_patches_2d, dct_2d
|
||||
from utils.colors import rgb2ycbcr
|
||||
|
||||
|
||||
|
@ -56,99 +55,3 @@ class GANLoss(nn.Module):
|
|||
target_label = self.get_target_label(input, target_is_real)
|
||||
loss = self.loss(input.float(), target_label.float())
|
||||
return loss
|
||||
|
||||
|
||||
# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL
|
||||
# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see scripts/compute_fdpl_perceptual_weights.py
|
||||
# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs).
|
||||
class FDPLLoss(nn.Module):
|
||||
"""
|
||||
Loss function taking the MSE between the 2D DCT coefficients
|
||||
of predicted and target images or an image channel.
|
||||
DCT coefficients are computed for each 8x8 block of the image
|
||||
|
||||
Important note about this loss: Since it operates in the frequency domain, precision is highly important.
|
||||
It works on FP64 numbers and will fail if you attempt to use it with AMP. Recommend you split this loss
|
||||
off from the rest of amp.scale_loss().
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_diff_means_file, device):
|
||||
"""
|
||||
dataset_diff_means (torch.tensor): Pre-computed frequency-domain mean differences between LR and HR images.
|
||||
"""
|
||||
# These values are derived from the JPEG standard.
|
||||
qt_Y = torch.tensor([[16, 11, 10, 16, 24, 40, 51, 61],
|
||||
[12, 12, 14, 19, 26, 58, 60, 55],
|
||||
[14, 13, 16, 24, 40, 57, 69, 56],
|
||||
[14, 17, 22, 29, 51, 87, 80, 62],
|
||||
[18, 22, 37, 56, 68, 109, 103, 77],
|
||||
[24, 35, 55, 64, 81, 104, 113, 92],
|
||||
[49, 64, 78, 87, 103, 121, 120, 101],
|
||||
[72, 92, 95, 98, 112, 100, 103, 99]],
|
||||
dtype=torch.double,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
qt_C = torch.tensor([[17, 18, 24, 47, 99, 99, 99, 99],
|
||||
[18, 21, 26, 66, 99, 99, 99, 99],
|
||||
[24, 26, 56, 99, 99, 99, 99, 99],
|
||||
[47, 66, 99, 99, 99, 99, 99, 99],
|
||||
[99, 99, 99, 99, 99, 99, 99, 99],
|
||||
[99, 99, 99, 99, 99, 99, 99, 99],
|
||||
[99, 99, 99, 99, 99, 99, 99, 99],
|
||||
[99, 99, 99, 99, 99, 99, 99, 99]],
|
||||
dtype=torch.double,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
"""
|
||||
Reasoning behind this perceptual weight matrix: JPEG gives as a model of frequencies that are important
|
||||
for human perception. In that model, lower frequencies are more important than higher frequencies. Because
|
||||
of this, the higher frequencies are the first to go during compression. As compression increases, the affect
|
||||
spreads to the lower frequencies, which degrades perceptual quality. But when the lower frequencies are
|
||||
preserved, JPEG does an excellent job of compression without a noticeable loss of quality.
|
||||
In super resolution, we already have the low frequencies. In fact that is really all we have in the low
|
||||
resolution images.
|
||||
As evidenced by the diff_means matrix above, what is lost in the SR images is the mid-range frequencies;
|
||||
those across and towards the centre of the diagonal. We can bias our model to recover these frequencies
|
||||
by having our loss function prioritize these coefficients, with priority determined by the magnitude of
|
||||
relative change between the low-res and high-res images. But we can take this further and into a true
|
||||
preceptual loss by further prioritizing DCT coefficients by the importance that has been assigned to them
|
||||
by the JPEG quantization table. That is how the table below is created.
|
||||
|
||||
The problem is that we don't know if the JPEG model is optimal. So there is room for qualitative evaluation
|
||||
of the quantization table values. We can further our perspective weights deleting each in turn for a small
|
||||
set of images and evaluating the resulting change in percieved quality. I can do this on my own to start and
|
||||
if it works, I can do a small user study to determine this.
|
||||
"""
|
||||
diff_means = torch.tensor(torch.load(dataset_diff_means_file), device=device)
|
||||
perceptual_weights = torch.stack([(torch.ones_like(qt_Y, device=device) / qt_Y),
|
||||
(torch.ones_like(qt_C, device=device) / qt_C),
|
||||
(torch.ones_like(qt_C, device=device) / qt_C)])
|
||||
perceptual_weights = perceptual_weights * diff_means
|
||||
self.perceptual_weights = perceptual_weights / torch.mean(perceptual_weights)
|
||||
super(FDPLLoss, self).__init__()
|
||||
|
||||
def forward(self, predictions, targets):
|
||||
"""
|
||||
Args:
|
||||
predictions (torch.tensor): output of an image transformation model.
|
||||
shape: batch_size x 3 x H x W
|
||||
targets (torch.tensor): ground truth images corresponding to outputs
|
||||
shape: batch_size x 3 x H x W
|
||||
criterion (torch.nn.MSELoss): object used to calculate MSE
|
||||
|
||||
Returns:
|
||||
loss (float): MSE between predicted and ground truth 2D DCT coefficients
|
||||
"""
|
||||
# transition to fp64 and then convert to YCC color space.
|
||||
predictions = rgb2ycbcr(predictions.double())
|
||||
targets = rgb2ycbcr(targets.double())
|
||||
|
||||
# get DCT coefficients of ground truth patches
|
||||
patches = extract_patches_2d(img=targets, patch_shape=(8, 8), batch_first=True)
|
||||
ground_truth_dct = dct_2d(patches, norm='ortho')
|
||||
|
||||
# get DCT coefficients of transformed images
|
||||
patches = extract_patches_2d(img=predictions, patch_shape=(8, 8), batch_first=True)
|
||||
outputs_dct = dct_2d(patches, norm='ortho')
|
||||
loss = torch.sum(((outputs_dct - ground_truth_dct).pow(2)) * self.perceptual_weights)
|
||||
return loss
|
Loading…
Reference in New Issue
Block a user