import os
from random import shuffle

import matplotlib
import matplotlib.cm as cm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision.models.resnet import Bottleneck
from tqdm import tqdm

from data.image_folder_dataset import ImageFolderDataset
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
from utils.kmeans import kmeans, kmeans_predict
from utils.options import dict_to_nonedict


def structural_euc_dist(x, y):
    """Compute the "structural" Euclidean distance between x and y.

    "Structural" means the leading (e.g. [h, w]) dimensions are preserved and
    the distance is taken across the last (channel) dimension only.
    """
    squared_sum = torch.square(x - y).sum(dim=-1)
    return torch.sqrt(squared_sum)


def cosine_similarity(x, y):
    """Return the negated cosine similarity of `norm(x)` and `norm(y)`.

    NOTE(review): `norm` standardizes along the last dim, but
    `nn.CosineSimilarity()` defaults to dim=1; confirm that is intended for the
    input ranks used here.
    """
    x = norm(x)
    y = norm(y)
    # Probably better to just use nn.CosineSimilarity directly to perform the
    # calc; kept here as a reminder.
    return -nn.CosineSimilarity()(x, y)


def key_value_difference(x, y):
    """Return 2 - 2*cos(x, y) along the last dim.

    For L2-normalized vectors this equals their squared Euclidean distance.
    """
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


def norm(x):
    """Standardize x along its last dim (subtract mean, divide by std).

    NOTE(review): a zero-variance slice divides by zero and yields NaN/inf.
    """
    return (x - torch.mean(x, dim=-1, keepdim=True)) / torch.std(x, dim=-1, keepdim=True)


def im_norm(x):
    """Standardize each plane over dims (2, 3), then rescale to mean .5 / std .5.

    Presumably x is a (b, c, h, w) image batch. keepdim=True keeps the
    statistics at shape (b, c, 1, 1) so they broadcast per sample and channel.
    The previous reshape(-1, 1, 1, 1) only broadcast correctly when c == 1.
    """
    mean = torch.mean(x, dim=(2, 3), keepdim=True)
    std = torch.std(x, dim=(2, 3), keepdim=True)
    return ((x - mean) / std) * .5 + .5


def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
    """Build a shuffled DataLoader over the hard-coded ImageFolderDataset path(s)."""
    dataset_opt = dict_to_nonedict({
        'name': 'amalgam',
        #'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'],
        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
        'weights': [1],
        'target_size': target_size,
        'force_multiple': 32,
        'scale': 1
    })
    dataset = ImageFolderDataset(dataset_opt)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)


def produce_latent_dict(model, basename, samples_per_batch=100):
    """Sample per-pixel latents from `model` over the dataset and save them.

    Writes `(latents, paths)` to `../{basename}_latent_dict.pth`. `latents` is a
    list of (1, c) tensors and `paths` is the list of every image path seen.
    The file is written roughly every 5000 images and once more at the end.

    Args:
        model: network mapping an image batch to a (b, c, h, w) latent map.
        basename: prefix used for the output file name.
        samples_per_batch: number of latent vectors drawn, without replacement,
            from each batch. It is capped at the number available.
    """
    batch_size = 64
    num_workers = 4
    dataloader = get_image_folder_dataloader(batch_size, num_workers)
    save_path = f'../{basename}_latent_dict.pth'
    images_since_save = 0
    paths = []
    latents = []
    for batch in tqdm(dataloader):
        hq = batch['hq'].to('cuda')
        l = model(hq)
        b, c, h, w = l.shape
        dim = b*h*w
        # Flatten to one row per spatial position across the whole batch.
        l = l.permute(0, 2, 3, 1).reshape(dim, c).cpu()
        # Draw a uniform random subset of rows without replacement. The index
        # set is sized from this batch's own `dim`, so a short final batch
        # cannot index out of range.
        idx = torch.randperm(dim)[:min(samples_per_batch, dim)]
        latents.extend(l[idx].split(1, dim=0))
        paths.extend(batch['HQ_path'])
        images_since_save += b
        if images_since_save > 5000:
            print("Saving checkpoint..")
            torch.save((latents, paths), save_path)
            images_since_save = 0
    # Persist whatever was gathered after the last periodic checkpoint.
    print("Saving final latent dict..")
    torch.save((latents, paths), save_path)


def build_kmeans(basename):
    """Cluster the saved latents into 8 groups and save the result.

    Reads `../{basename}_latent_dict.pth` and shuffles the latent list. It then
    runs the project `kmeans` helper (euclidean, on cuda:0) and writes
    `(cluster_ids_x, cluster_centers)` to `../{basename}_k_means_centroids.pth`.
    """
    latents, _ = torch.load(f'../{basename}_latent_dict.pth')
    shuffle(latents)
    latents = torch.cat(latents, dim=0).to('cuda')
    cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean",
                                            device=torch.device('cuda:0'), tol=0, iter_limit=5000,
                                            gravity_limit_per_iter=1000)
    torch.save((cluster_ids_x, cluster_centers), f'../{basename}_k_means_centroids.pth')


def _discrete_colormap(name, num_colors):
    """Return matplotlib colormap `name` resampled to `num_colors` entries.

    `matplotlib.cm.get_cmap` was deprecated in 3.7 and removed in 3.9, so prefer
    the `matplotlib.colormaps` registry and fall back on older releases.
    """
    try:
        return matplotlib.colormaps[name].resampled(num_colors)
    except AttributeError:
        return cm.get_cmap(name, num_colors)


def use_kmeans(basename, model=None):
    """Colour every latent by its nearest k-means centroid and save images.

    For each batch it writes `{i}_categories.png` (cluster map upsampled to
    the input size) and `{i}_hq.png` (the input) under
    `../results/{basename}_kmeans_viz`.

    Args:
        basename: prefix of the centroid file and output directory.
        model: network producing (b, c, h, w) latents. If omitted, the
            module-level `model` is used, which preserves the original behavior
            when run as a script.

    Raises:
        ValueError: if no model is passed and no module-level `model` exists.
    """
    if model is None:
        model = globals().get('model')
        if model is None:
            raise ValueError("use_kmeans() needs a model: pass one explicitly "
                             "or define a module-level `model`.")
    output_path = f'../results/{basename}_kmeans_viz'
    _, centers = torch.load(f'../{basename}_k_means_centroids.pth')
    centers = centers.to('cuda')
    batch_size = 8
    num_workers = 0
    dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256)
    # 8 colours, one per cluster built in build_kmeans().
    colormap = _discrete_colormap('viridis', 8)
    os.makedirs(output_path, exist_ok=True)
    for i, batch in enumerate(tqdm(dataloader)):
        hq = batch['hq'].to('cuda')
        l = model(hq)
        b, c, h, w = l.shape
        dim = b*h*w
        l = l.permute(0, 2, 3, 1).reshape(dim, c)
        pred = kmeans_predict(l, centers)
        pred = pred.reshape(b, h, w)
        # Integer cluster ids index the colormap LUT directly -> (b, h, w, RGBA).
        img = torch.tensor(colormap(pred.detach().cpu().numpy()))
        scale = hq.shape[-2] / h
        torchvision.utils.save_image(F.interpolate(img.permute(0, 3, 1, 2), scale_factor=scale, mode="nearest"),
                                     f"{output_path}/{i}_categories.png")
        torchvision.utils.save_image(hq, f"{output_path}/{i}_hq.png")


if __name__ == '__main__':
    pretrained_path = '../experiments/uresnet_pixpro4_imgset.pth'
    basename = 'uresnet_pixpro4'
    model = UResNet50_3(Bottleneck, [3, 4, 6, 3], out_dim=64).to('cuda')
    sd = torch.load(pretrained_path)
    # Keep only the target encoder's weights, stripped of their prefix.
    resnet_sd = {k.replace('target_encoder.net.', ''): v
                 for k, v in sd.items() if 'target_encoder.net.' in k}
    model.load_state_dict(resnet_sd, strict=True)
    model.eval()

    with torch.no_grad():
        # Pipeline stages; uncomment the one(s) to run. find_similar_latents and
        # create_latent_database are not defined in this file.
        #find_similar_latents(model, 0, 8, structural_euc_dist)
        #create_latent_database(model, batch_size=32)
        #produce_latent_dict(model, basename)
        #build_kmeans(basename)
        use_kmeans(basename, model)