146 lines
5.1 KiB
Python
146 lines
5.1 KiB
Python
import os
|
|
import shutil
|
|
|
|
import matplotlib.cm as cm
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
from PIL import Image
|
|
from torch.utils.data import DataLoader
|
|
from torchvision.models.resnet import Bottleneck
|
|
from torchvision.transforms import ToTensor, Resize
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
|
|
import utils
|
|
from data.image_folder_dataset import ImageFolderDataset
|
|
from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
|
|
from models.resnet_with_checkpointing import resnet50
|
|
from models.spinenet_arch import SpineNet
|
|
|
|
|
|
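# structural_euc_dist() (defined below the imports) computes the structural euclidean distance between [x,y].
# "Structural" here means the leading dimensions (e.g. [h,w]) are preserved and the distance is computed
# across the last dimension, which is expected to hold the channels.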
|
|
from scripts.byol.byol_spinenet_playground import find_similar_latents, create_latent_database
|
|
from utils import util
|
|
from utils.kmeans import kmeans, kmeans_predict
|
|
from utils.options import dict_to_nonedict
|
|
|
|
|
|
def structural_euc_dist(x, y):
    """Return the euclidean distance between x and y over the last dimension.

    The element-wise difference is squared, summed along dim=-1 and square
    rooted, so every leading dimension of the (broadcast) inputs is kept and
    only the last one is reduced away.
    """
    return (x - y).square().sum(dim=-1).sqrt()
|
|
|
|
|
|
def cosine_similarity(x, y):
    """Return the negated cosine similarity of x and y along the last dimension.

    Both inputs are first standardized with norm() (zero mean, unit standard
    deviation along the last dimension). The similarity is then reduced over
    that same last dimension, matching structural_euc_dist() and
    key_value_difference(). nn.CosineSimilarity defaults to dim=1, which only
    coincides with the last dimension for 2-D inputs, so the dimension is
    passed explicitly.

    The sign is flipped on return -- presumably so that smaller values mean
    "more similar", like a distance; confirm against callers.
    """
    x = norm(x)
    y = norm(y)
    # Delegating to nn.CosineSimilarity is preferable to hand-rolling the calculation.
    return -nn.CosineSimilarity(dim=-1)(x, y)
|
|
|
|
|
|
def key_value_difference(x, y):
    """Return 2 - 2 * <x_hat, y_hat> computed along the last dimension.

    x_hat and y_hat are the inputs L2-normalized along dim=-1, so the result
    is 0 for vectors pointing the same way and 4 for opposite vectors.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    dot = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * dot
|
|
|
|
|
|
def norm(x):
    """Standardize x along its last dimension.

    Subtracts the mean and divides by the standard deviation (torch's default
    estimator), both taken over dim=-1 and broadcast back over that dimension.
    A constant slice has zero standard deviation and therefore yields NaNs.
    """
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True)
    return (x - mean) / std
|
|
|
|
|
|
def im_norm(x):
    """Normalize each channel of each image to mean 0.5 and standard deviation 0.5.

    x is indexed as a 4-D tensor; the statistics are taken over dims (2, 3)
    (assumed to be the spatial [h, w] axes of a [b, c, h, w] batch -- confirm
    with callers).

    The statistics are kept as [b, c, 1, 1] via keepdim=True so they broadcast
    against x for any channel count. Reshaping them to (-1, 1, 1, 1) instead
    only lines up with x when c == 1; with several channels it either fails to
    broadcast or silently produces a [c, c, h, w] result for a single image.

    A constant channel has zero standard deviation and yields NaNs.
    """
    mean = torch.mean(x, dim=(2, 3), keepdim=True)
    std = torch.std(x, dim=(2, 3), keepdim=True)
    return ((x - mean) / std) * .5 + .5
|
|
|
|
|
|
def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
    """Build a shuffling DataLoader over the hard-coded ImageNet training folder.

    Args:
        batch_size: forwarded to the DataLoader.
        num_workers: forwarded to the DataLoader.
        target_size: stored under the 'target_size' key of the dataset options.

    Returns:
        A DataLoader wrapping an ImageFolderDataset built from the options below.
    """
    # Other image sets that have been used in place of the ImageNet path:
    #   'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
    #   'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'
    #   'F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'
    raw_options = {
        'name': 'amalgam',
        'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'],
        'weights': [1],
        'target_size': target_size,
        'force_multiple': 32,
        'scale': 1,
    }
    image_dataset = ImageFolderDataset(dict_to_nonedict(raw_options))
    return DataLoader(
        image_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
|
|
|
|
|
|
def produce_latent_dict(model):
    """Sample latent vectors from `model` over the image dataset and save them to disk.

    Every batch's 'hq' images are pushed through `model` on the GPU. The
    [b, c, h, w] output is flattened to b*h*w vectors of length c, and up to 10
    of those vectors are drawn uniformly at random, without replacement, from
    the whole batch (not per image).

    The tuple (latents, paths) is written to '../imagenet_latent_dict.pth'
    roughly every 1000 images and once more when the dataset is exhausted.
    `latents` is a list of [1, c] CPU tensors; `paths` collects every
    batch['HQ_path'] entry. The two lists are therefore NOT index-aligned:
    each batch contributes 10 latents but b paths.

    Args:
        model: callable mapping the 'hq' image batch to a 4-D tensor.
    """
    batch_size = 32
    num_workers = 4
    samples_per_batch = 10
    save_interval = 1000
    output_path = '../imagenet_latent_dict.pth'

    dataloader = get_image_folder_dataloader(batch_size, num_workers)
    images_since_save = 0
    paths = []
    latents = []
    prob = None
    for batch in tqdm(dataloader):
        hq = batch['hq'].to('cuda')
        latent = model(hq)
        b, c, h, w = latent.shape
        dim = b * h * w
        # Channels-last, then one row per spatial position of every image in the batch.
        latent = latent.permute(0, 2, 3, 1).reshape(dim, c).cpu()

        # Uniform sampling distribution over the rows. It is rebuilt whenever the row
        # count changes (e.g. a smaller final batch); reusing the first batch's
        # distribution would produce indices past the end of `latent`.
        if prob is None or prob.numel() != dim:
            prob = torch.full((dim,), 1 / dim)
        # Sampling without replacement cannot draw more rows than exist.
        picks = prob.multinomial(num_samples=min(samples_per_batch, dim), replacement=False)
        latents.extend(latent[picks].split(1, dim=0))
        paths.extend(batch['HQ_path'])

        images_since_save += batch_size
        if images_since_save > save_interval:
            print("Saving checkpoint..")
            torch.save((latents, paths), output_path)
            images_since_save = 0

    # Persist whatever was gathered after the last periodic checkpoint; without this
    # the tail of the run (or a whole short run) would never reach the disk.
    if images_since_save > 0:
        print("Saving checkpoint..")
        torch.save((latents, paths), output_path)
|
|
|
|
|
|
def build_kmeans():
    """Cluster the saved latents into 4 groups and store the result.

    Reads the (latents, paths) tuple from '../imagenet_latent_dict.pth',
    concatenates the latents along dim 0 on the GPU, runs euclidean k-means
    with 4 clusters and writes (cluster_ids_x, cluster_centers) to
    '../k_means_imagenet.pth'.
    """
    latent_chunks, _ = torch.load('../imagenet_latent_dict.pth')
    stacked_latents = torch.cat(latent_chunks, dim=0).to('cuda')
    cluster_ids_x, cluster_centers = kmeans(
        stacked_latents,
        num_clusters=4,
        distance="euclidean",
        device=torch.device('cuda:0'),
    )
    torch.save((cluster_ids_x, cluster_centers), '../k_means_imagenet.pth')
|
|
|
|
|
|
def use_kmeans():
    """Render per-position k-means cluster assignments for dataset images.

    Loads the cluster centers from '../k_means_imagenet.pth', runs every batch
    through the network, assigns each spatial position of the output to its
    cluster with kmeans_predict, and writes two images per batch into the
    current directory: '{i}_categories.png' (assignments coloured with the
    viridis colormap and upscaled 8x with nearest-neighbour interpolation) and
    '{i}_hq.png' (the input batch).

    NOTE: this reads the module-level name `model`, which is only assigned in
    the `__main__` block; it is not a parameter of this function.
    """
    _, centers = torch.load('../k_means_imagenet.pth')
    centers = centers.to('cuda')

    batch_size = 8
    num_workers = 0
    dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256)
    colormap = cm.get_cmap('viridis', 8)

    for i, batch in enumerate(tqdm(dataloader)):
        hq = batch['hq'].to('cuda')
        features = model(hq)
        b, c, h, w = features.shape
        # Channels-last, then one row per spatial position so each can be classified.
        flat_features = features.permute(0, 2, 3, 1).reshape(b * h * w, c)
        assignments = kmeans_predict(flat_features, centers).reshape(b, h, w)

        # The colormap appends a trailing colour axis, which is moved to dim 1 before saving.
        coloured = torch.tensor(colormap(assignments.detach().cpu().numpy()))
        upscaled = torch.nn.functional.interpolate(
            coloured.permute(0, 3, 1, 2), scale_factor=8, mode="nearest"
        )
        torchvision.utils.save_image(upscaled, f"{i}_categories.png")
        torchvision.utils.save_image(hq, f"{i}_hq.png")
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the network, load the target-encoder weights from the
    # training checkpoint, then run one stage of the latent / k-means pipeline.
    pretrained_path = '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth'
    model = UResNet50(Bottleneck, [3,4,6,3], out_dim=256).to('cuda')

    # Keep only the entries whose key contains the target-encoder prefix, with that
    # prefix removed so the keys line up with the bare network.
    encoder_prefix = 'target_encoder.net.'
    checkpoint = torch.load(pretrained_path)
    resnet_sd = {
        key.replace(encoder_prefix, ''): weight
        for key, weight in checkpoint.items()
        if encoder_prefix in key
    }
    model.load_state_dict(resnet_sd, strict=True)
    model.eval()

    # Inference only. Uncomment the stage to run: produce_latent_dict() writes the file
    # build_kmeans() reads, and build_kmeans() writes the file use_kmeans() reads.
    with torch.no_grad():
        #find_similar_latents(model, 0, 8, structural_euc_dist)
        #create_latent_database(model, batch_size=32)
        #produce_latent_dict(model)
        build_kmeans()
        #use_kmeans()
|