DL-Art-School/codes/scripts/byol/byol_uresnet_playground.py

143 lines
5.1 KiB
Python
Raw Normal View History

2021-01-06 21:52:17 +00:00
import os
2021-01-23 20:46:43 +00:00
from random import shuffle
2021-01-06 21:52:17 +00:00
import matplotlib.cm as cm
2021-01-06 21:52:17 +00:00
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision.models.resnet import Bottleneck
from tqdm import tqdm
from data.image_folder_dataset import ImageFolderDataset
2021-01-23 20:46:43 +00:00
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
2021-01-06 21:52:17 +00:00
# Computes the "structural" Euclidean distance between x and y. "Structural" here means all leading dimensions
# (e.g. [h,w]) are preserved and the distance is computed across the last (channel) dimension.
from utils.kmeans import kmeans, kmeans_predict
2021-01-06 21:52:17 +00:00
from utils.options import dict_to_nonedict
def structural_euc_dist(x, y):
    """Return the Euclidean distance between x and y along the last dimension.

    All leading dimensions are preserved, so channel-last inputs of shape
    (..., h, w, c) yield a result of shape (..., h, w).
    """
    squared_diff = torch.square(x - y)
    # Named to avoid shadowing the builtin `sum`.
    squared_dist = torch.sum(squared_diff, dim=-1)
    return torch.sqrt(squared_dist)
def cosine_similarity(x, y):
    """Return the negated cosine similarity of x and y along the last dimension.

    Both inputs are first standardized with `norm` (which works over dim=-1), and the
    similarity is taken over that same dimension so the two steps agree. The result is
    negated so that, like the other distance helpers here, smaller means more similar.
    """
    x = norm(x)
    y = norm(y)
    # dim=-1 matches `norm` and the other distance helpers; for 2D inputs this is
    # identical to nn.CosineSimilarity()'s default of dim=1.
    return -F.cosine_similarity(x, y, dim=-1)
def key_value_difference(x, y):
    """Return 2 - 2*cos(x, y) computed along the last dimension.

    Both inputs are L2-normalized over dim=-1 first, so the result ranges from
    0 (same direction) to 4 (opposite directions).
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cos = torch.sum(x_unit * y_unit, dim=-1)
    return 2 - 2 * cos
def norm(x):
    """Standardize x along its last dimension (subtract mean, divide by std).

    Uses torch.std's default (unbiased) estimator. The statistics keep a size-1
    last dimension so they broadcast back over x.
    """
    centered = x - x.mean(dim=-1, keepdim=True)
    return centered / x.std(dim=-1, keepdim=True)
def im_norm(x):
    """Standardize each image channel over its spatial dims, then map to mean .5 / std .5.

    Statistics are taken over dims (2, 3), so x is expected to be a (b, c, h, w) batch.
    They are kept as (b, c, 1, 1) so every (image, channel) pair is normalized by its
    own mean/std. The previous reshape to (b*c, 1, 1, 1) only broadcast correctly when
    c == 1.
    """
    mean = torch.mean(x, dim=(2, 3), keepdim=True)
    std = torch.std(x, dim=(2, 3), keepdim=True)
    return ((x - mean) / std) * .5 + .5
def get_image_folder_dataloader(batch_size, num_workers, target_size=256, paths=None, weights=None):
    """Build a shuffling DataLoader over an ImageFolderDataset.

    Args:
        batch_size: Images per batch.
        num_workers: DataLoader worker processes.
        target_size: Forwarded to the dataset options as 'target_size'.
        paths: Image folders to read. Defaults to the local 256px imageset this
            script was developed against.
        weights: Forwarded to the dataset options as 'weights'. Defaults to 1 per
            path (presumably one weight per entry in `paths` -- confirm against
            ImageFolderDataset).

    Returns:
        A DataLoader over the dataset with shuffle=True.
    """
    if paths is None:
        # Other locations this script has been pointed at:
        #   'F:\\4k6k\\datasets\\images\\imagenet_2017\\train'
        #   'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'
        #   'F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'
        paths = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full']
    if weights is None:
        weights = [1] * len(paths)
    dataset_opt = dict_to_nonedict({
        'name': 'amalgam',
        'paths': paths,
        'weights': weights,
        'target_size': target_size,
        'force_multiple': 32,
        'scale': 1
    })
    dataset = ImageFolderDataset(dataset_opt)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
2021-01-23 20:46:43 +00:00
def produce_latent_dict(model, basename, batch_size=64, num_workers=4, samples_per_batch=100,
                        checkpoint_every=5000):
    """Sample per-position latents from `model` over the image dataset and save them.

    For each batch, the model output (unpacked as b, c, h, w) is flattened to one
    c-dim row per (image, y, x) position. Then `samples_per_batch` rows are drawn
    uniformly without replacement from the whole batch, not per image. The running
    list of sampled latents (each a (1, c) CPU tensor) and the image paths seen are
    saved as a tuple to ``../{basename}_latent_dict.pth``. This happens every time
    more than `checkpoint_every` images have been processed since the last save, and
    once more after the dataset is exhausted.

    Args:
        model: Callable mapping a CUDA image batch to a rank-4 feature tensor.
        basename: Prefix for the output file.
        batch_size: Images per batch.
        num_workers: DataLoader worker processes.
        samples_per_batch: Latents sampled per batch (fewer if the batch has fewer positions).
        checkpoint_every: Images between periodic saves.
    """
    dataloader = get_image_folder_dataloader(batch_size, num_workers)
    save_path = f'../{basename}_latent_dict.pth'
    paths = []
    latents = []
    images_since_save = 0
    for batch in tqdm(dataloader):
        hq = batch['hq'].to('cuda')
        features = model(hq)
        b, c, h, w = features.shape
        num_positions = b * h * w
        # (b, c, h, w) -> (b*h*w, c): one row per spatial position.
        flat = features.permute(0, 2, 3, 1).reshape(num_positions, c).cpu()
        # Draw the permutation per batch so a short final batch cannot index out of range.
        picks = torch.randperm(num_positions)[:samples_per_batch]
        latents.extend(flat[picks].split(1, dim=0))
        paths.extend(batch['HQ_path'])
        images_since_save += b
        if images_since_save > checkpoint_every:
            print("Saving checkpoint..")
            torch.save((latents, paths), save_path)
            images_since_save = 0
    # Persist everything gathered after the last periodic checkpoint as well.
    torch.save((latents, paths), save_path)
2021-01-23 20:46:43 +00:00
def build_kmeans(basename, num_clusters=8):
    """Run k-means over the latents saved by produce_latent_dict and save the result.

    Loads the ``(latents, paths)`` tuple from ``../{basename}_latent_dict.pth``, shuffles
    the latent list, concatenates it into one tensor on CUDA, and clusters it with the
    project's `kmeans` helper (euclidean distance). Writes ``(cluster_ids, cluster_centers)``
    to ``../{basename}_k_means_centroids.pth``.

    Args:
        basename: Prefix shared with produce_latent_dict / use_kmeans output files.
        num_clusters: Number of k-means clusters (default 8).
    """
    latents, _ = torch.load(f'../{basename}_latent_dict.pth')
    # Shuffle before concatenating; presumably so the kmeans helper's initialization
    # isn't biased by dataset order -- confirm against utils.kmeans.
    shuffle(latents)
    data = torch.cat(latents, dim=0).to('cuda')
    cluster_ids_x, cluster_centers = kmeans(data, num_clusters=num_clusters, distance="euclidean",
                                            device=torch.device('cuda:0'), tol=0, iter_limit=5000,
                                            gravity_limit_per_iter=1000)
    torch.save((cluster_ids_x, cluster_centers), f'../{basename}_k_means_centroids.pth')
2021-01-23 20:46:43 +00:00
def use_kmeans(basename, model=None):
    """Render the k-means cluster assignment of every latent position as images.

    Loads centroids from ``../{basename}_k_means_centroids.pth``. For each dataset batch,
    it assigns every (image, y, x) latent to a centroid with `kmeans_predict`, colors the
    assignments with a viridis colormap that has one entry per centroid, upsamples them
    (nearest) to the input resolution, and writes ``{i}_categories.png`` next to
    ``{i}_hq.png`` under ``../results/{basename}_kmeans_viz``.

    Args:
        basename: Prefix shared with build_kmeans output files.
        model: Feature extractor returning a rank-4 tensor. If omitted, falls back to a
            module-level `model` (as defined by this script's __main__ block).
    """
    if model is None:
        model = globals().get('model')
        if model is None:
            raise NameError("use_kmeans() needs a model: pass one or define a module-level `model`.")
    output_path = f'../results/{basename}_kmeans_viz'
    _, centers = torch.load(f'../{basename}_k_means_centroids.pth')
    centers = centers.to('cuda')
    # Assumes centers is (num_clusters, latent_dim) -- TODO confirm against utils.kmeans.
    num_clusters = centers.shape[0]
    batch_size = 8
    num_workers = 0
    dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256)
    # NOTE(review): cm.get_cmap was removed in matplotlib 3.9; on newer versions use
    # matplotlib.colormaps['viridis'].resampled(num_clusters).
    colormap = cm.get_cmap('viridis', num_clusters)
    os.makedirs(output_path, exist_ok=True)
    for i, batch in enumerate(tqdm(dataloader)):
        hq = batch['hq'].to('cuda')
        features = model(hq)
        b, c, h, w = features.shape
        # (b, c, h, w) -> (b*h*w, c): one row per spatial position.
        flat = features.permute(0, 2, 3, 1).reshape(b * h * w, c)
        pred = kmeans_predict(flat, centers).reshape(b, h, w)
        # Integer cluster ids index straight into the colormap, giving (b, h, w, 4) RGBA.
        img = torch.tensor(colormap(pred.detach().cpu().numpy()))
        scale = hq.shape[-2] / h
        torchvision.utils.save_image(
            torch.nn.functional.interpolate(img.permute(0, 3, 1, 2), scale_factor=scale, mode="nearest"),
            f"{output_path}/{i}_categories.png")
        torchvision.utils.save_image(hq, f"{output_path}/{i}_hq.png")
2021-01-06 21:52:17 +00:00
if __name__ == '__main__':
    pretrained_path = '../experiments/uresnet_pixpro4_imgset.pth'
    basename = 'uresnet_pixpro4'
    model = UResNet50_3(Bottleneck, [3, 4, 6, 3], out_dim=64).to('cuda')
    # Load on CPU: only the filtered subset below is copied into the CUDA model, so
    # there is no need to materialize the whole checkpoint on the GPU.
    sd = torch.load(pretrained_path, map_location='cpu')
    # Keep only the target-encoder weights and strip their prefix so the keys match UResNet50_3.
    resnet_sd = {k.replace('target_encoder.net.', ''): v
                 for k, v in sd.items() if 'target_encoder.net.' in k}
    model.load_state_dict(resnet_sd, strict=True)
    model.eval()
    with torch.no_grad():
        # Pipeline stages, in order; uncomment the ones to run.
        #produce_latent_dict(model, basename)
        #build_kmeans(basename)
        use_kmeans(basename)