DL-Art-School/codes/scripts/byol/byol_uresnet_playground.py

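# Playground for pixel-level contrastive (BYOL/PixPro-style) UResNet latents: dump
# sampled latents to disk, fit k-means over them, then visualize per-pixel cluster
# assignments. Dataset paths and checkpoint names are specific to the author's machine.
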
import os
import shutil
from random import shuffle

import matplotlib.cm as cm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.models.resnet import Bottleneck
from torchvision.transforms import ToTensor, Resize
from tqdm import tqdm
import numpy as np
import utils
from data.image_folder_dataset import ImageFolderDataset
from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
from models.resnet_with_checkpointing import resnet50
from models.spinenet_arch import SpineNet
from scripts.byol.byol_spinenet_playground import find_similar_latents, create_latent_database
from utils import util
from utils.kmeans import kmeans, kmeans_predict
from utils.options import dict_to_nonedict


# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
def structural_euc_dist(x, y):
    diff = torch.square(x - y)
    dist = torch.sum(diff, dim=-1)
    return torch.sqrt(dist)
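
# Illustrative note: for channel-last latents shaped [b, h, w, c],
# structural_euc_dist(x, y) returns a [b, h, w] per-pixel distance map.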


def cosine_similarity(x, y):
    x = norm(x)
    y = norm(y)
    return -nn.CosineSimilarity()(x, y)  # probably better to just use this class to perform the calc directly; left here as a reminder.


def key_value_difference(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
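
# Note: for L2-normalized vectors, ||x - y||^2 = 2 - 2 * <x, y>, so this is the usual
# BYOL-style regression loss between two embeddings.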


# Standardizes x to zero mean / unit std across the last (channel) dimension.
def norm(x):
    sh = x.shape
    sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
    return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r)


# Standardizes each image in a [b, c, h, w] batch, then rescales so values center
# on .5 for visualization.
def im_norm(x):
    return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
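
# Illustrative usage (hypothetical): torchvision.utils.save_image(im_norm(hq), 'out.png')
# writes a contrast-normalized copy of an image batch; values outside [0,1] are clamped.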


def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
    dataset_opt = dict_to_nonedict({
        'name': 'amalgam',
        #'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'],
        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
        'weights': [1],
        'target_size': target_size,
        'force_multiple': 32,
        'scale': 1
    })
    dataset = ImageFolderDataset(dataset_opt)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
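
# Batches from this loader carry the image tensor under 'hq' and its source path under
# 'HQ_path', which the functions below rely on.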


# Pushes the dataset through the model, keeps a random sample of 100 pixel-level
# latents per batch, and periodically checkpoints the (latents, paths) pair to disk.
def produce_latent_dict(model, basename):
    batch_size = 64
    num_workers = 4
    dataloader = get_image_folder_dataloader(batch_size, num_workers)
    imgs_since_save = 0
    paths = []
    latents = []
    prob = None
    for batch in tqdm(dataloader):
        hq = batch['hq'].to('cuda')
        l = model(hq)
        b, c, h, w = l.shape
        dim = b*h*w
        l = l.permute(0,2,3,1).reshape(dim, c).cpu()
        # extract a random set of 100 latents from each batch
        if prob is None:
            prob = torch.full((dim,), 1/dim)
        l = l[prob.multinomial(num_samples=100, replacement=False)].split(1, dim=0)
        latents.extend(l)
        paths.extend(batch['HQ_path'])
        imgs_since_save += batch_size
        if imgs_since_save > 5000:
            print("Saving checkpoint..")
            torch.save((latents, paths), f'../{basename}_latent_dict.pth')
            imgs_since_save = 0
    # Save the tail accumulated since the last checkpoint.
    torch.save((latents, paths), f'../{basename}_latent_dict.pth')


# Fits k-means (8 clusters, matching the 8-color map used below) over the saved latents.
def build_kmeans(basename):
    latents, _ = torch.load(f'../{basename}_latent_dict.pth')
    shuffle(latents)
    latents = torch.cat(latents, dim=0).to('cuda')
    cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'), tol=0, iter_limit=5000, gravity_limit_per_iter=1000)
    torch.save((cluster_ids_x, cluster_centers), f'../{basename}_k_means_centroids.pth')


# Assigns each pixel-level latent to its nearest k-means centroid and saves a
# color-coded category map alongside each source image.
def use_kmeans(model, basename):
    output_path = f'../results/{basename}_kmeans_viz'
    _, centers = torch.load(f'../{basename}_k_means_centroids.pth')
    centers = centers.to('cuda')
    batch_size = 8
    num_workers = 0
    dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256)
    colormap = cm.get_cmap('viridis', 8)
    os.makedirs(output_path, exist_ok=True)
    for i, batch in enumerate(tqdm(dataloader)):
        hq = batch['hq'].to('cuda')
        l = model(hq)
        b, c, h, w = l.shape
        dim = b*h*w
        l = l.permute(0,2,3,1).reshape(dim,c)
        pred = kmeans_predict(l, centers)
        pred = pred.reshape(b,h,w)
        img = torch.tensor(colormap(pred.detach().cpu().numpy()))
        # Nearest-neighbor upscale the category map back to the source resolution.
        scale = hq.shape[-2] / h
        torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=scale, mode="nearest"),
                                     f"{output_path}/{i}_categories.png")
        torchvision.utils.save_image(hq, f"{output_path}/{i}_hq.png")


if __name__ == '__main__':
    pretrained_path = '../experiments/uresnet_pixpro4_imgset.pth'
    basename = 'uresnet_pixpro4'
    model = UResNet50_3(Bottleneck, [3,4,6,3], out_dim=64).to('cuda')
    sd = torch.load(pretrained_path)
    # The checkpoint holds the full BYOL/PixPro wrapper; strip the 'target_encoder.net.'
    # prefix to recover the bare UResNet weights.
    resnet_sd = {}
    for k, v in sd.items():
        if 'target_encoder.net.' in k:
            resnet_sd[k.replace('target_encoder.net.', '')] = v
    model.load_state_dict(resnet_sd, strict=True)
    model.eval()

    # Pick a step by uncommenting it: produce_latent_dict -> build_kmeans -> use_kmeans.
    with torch.no_grad():
        #find_similar_latents(model, 0, 8, structural_euc_dist)
        #create_latent_database(model, batch_size=32)
        #produce_latent_dict(model, basename)
        #build_kmeans(basename)
        use_kmeans(model, basename)