diff --git a/codes/scripts/byol/byol_uresnet_playground.py b/codes/scripts/byol/byol_uresnet_playground.py index 1eb01068..a09dd83c 100644 --- a/codes/scripts/byol/byol_uresnet_playground.py +++ b/codes/scripts/byol/byol_uresnet_playground.py @@ -1,5 +1,6 @@ import os import shutil +from random import shuffle import matplotlib.cm as cm import torch @@ -16,6 +17,8 @@ import numpy as np import utils from data.image_folder_dataset import ImageFolderDataset from models.pixel_level_contrastive_learning.resnet_unet import UResNet50 +from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2 +from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3 from models.resnet_with_checkpointing import resnet50 from models.spinenet_arch import SpineNet @@ -59,9 +62,9 @@ def im_norm(x): def get_image_folder_dataloader(batch_size, num_workers, target_size=256): dataset_opt = dict_to_nonedict({ 'name': 'amalgam', - 'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'], + #'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], - #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'], + 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], 'weights': [1], 'target_size': target_size, @@ -72,8 +75,8 @@ def get_image_folder_dataloader(batch_size, num_workers, target_size=256): return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) -def produce_latent_dict(model): - batch_size = 32 +def produce_latent_dict(model, basename): + batch_size = 64 num_workers = 4 dataloader = get_image_folder_dataloader(batch_size, num_workers) id = 0 @@ -89,30 +92,33 @@ def produce_latent_dict(model): # extract a random set of 10 latents from each image if prob is None: prob = torch.full((dim,), 1/(dim)) - l = 
l[prob.multinomial(num_samples=10, replacement=False)].split(1, dim=0) + l = l[prob.multinomial(num_samples=100, replacement=False)].split(1, dim=0) latents.extend(l) paths.extend(batch['HQ_path']) id += batch_size - if id > 1000: + if id > 5000: print("Saving checkpoint..") - torch.save((latents, paths), '../imagenet_latent_dict.pth') + torch.save((latents, paths), f'../{basename}_latent_dict.pth') id = 0 -def build_kmeans(): - latents, _ = torch.load('../imagenet_latent_dict.pth') +def build_kmeans(basename): + latents, _ = torch.load(f'../{basename}_latent_dict.pth') + shuffle(latents) latents = torch.cat(latents, dim=0).to('cuda') - cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=4, distance="euclidean", device=torch.device('cuda:0')) - torch.save((cluster_ids_x, cluster_centers), '../k_means_imagenet.pth') + cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'), tol=0, iter_limit=5000, gravity_limit_per_iter=1000) + torch.save((cluster_ids_x, cluster_centers), f'../{basename}_k_means_centroids.pth') -def use_kmeans(): - _, centers = torch.load('../k_means_imagenet.pth') +def use_kmeans(basename): + output_path = f'../results/{basename}_kmeans_viz' + _, centers = torch.load(f'../{basename}_k_means_centroids.pth') centers = centers.to('cuda') batch_size = 8 num_workers = 0 dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256) colormap = cm.get_cmap('viridis', 8) + os.makedirs(output_path, exist_ok=True) for i, batch in enumerate(tqdm(dataloader)): hq = batch['hq'].to('cuda') l = model(hq) @@ -122,13 +128,16 @@ def use_kmeans(): pred = kmeans_predict(l, centers) pred = pred.reshape(b,h,w) img = torch.tensor(colormap(pred[:, :, :].detach().cpu().numpy())) - torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=8, mode="nearest"), f"{i}_categories.png") - torchvision.utils.save_image(hq, f"{i}_hq.png") + scale = 
hq.shape[-2] / h + torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=scale, mode="nearest"), + f"{output_path}/{i}_categories.png") + torchvision.utils.save_image(hq, f"{output_path}/{i}_hq.png") if __name__ == '__main__': - pretrained_path = '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth' - model = UResNet50(Bottleneck, [3,4,6,3], out_dim=256).to('cuda') + pretrained_path = '../experiments/uresnet_pixpro4_imgset.pth' + basename = 'uresnet_pixpro4' + model = UResNet50_3(Bottleneck, [3,4,6,3], out_dim=64).to('cuda') sd = torch.load(pretrained_path) resnet_sd = {} for k, v in sd.items(): @@ -140,6 +149,6 @@ if __name__ == '__main__': with torch.no_grad(): #find_similar_latents(model, 0, 8, structural_euc_dist) #create_latent_database(model, batch_size=32) - #produce_latent_dict(model) - build_kmeans() - #use_kmeans() + #produce_latent_dict(model, basename) + #build_kmeans(basename) + use_kmeans(basename) diff --git a/codes/utils/kmeans.py b/codes/utils/kmeans.py index 0d21f3a4..1b67ecd5 100644 --- a/codes/utils/kmeans.py +++ b/codes/utils/kmeans.py @@ -1,5 +1,6 @@ # From: https://github.com/subhadarship/kmeans_pytorch # License: https://github.com/subhadarship/kmeans_pytorch/blob/master/LICENSE +import random import numpy as np import torch @@ -30,6 +31,7 @@ def kmeans( tol=1e-4, tqdm_flag=True, iter_limit=0, + gravity_limit_per_iter=None, device=torch.device('cpu') ): """ @@ -83,9 +85,10 @@ def kmeans( for index in range(num_clusters): selected = torch.nonzero(choice_cluster == index).squeeze().to(device) - selected = torch.index_select(X, 0, selected) - + if gravity_limit_per_iter and len(selected) > gravity_limit_per_iter: + ch = random.randint(0, len(selected)-gravity_limit_per_iter) + selected=selected[ch:ch+gravity_limit_per_iter] initial_state[index] = selected.mean(dim=0) center_shift = torch.sum( @@ -97,14 +100,16 @@ def kmeans( iteration = iteration + 1 # update tqdm meter + bins = 
torch.bincount(choice_cluster) if tqdm_flag: tqdm_meter.set_postfix( iteration=f'{iteration}', - center_shift=f'{center_shift ** 2:0.6f}', - tol=f'{tol:0.6f}' + center_shift=f'{center_shift ** 2}', + tol=f'{tol}', + bins=f'{bins}', ) tqdm_meter.update() - if center_shift ** 2 < tol: + if tol > 0 and center_shift ** 2 < tol: break if iter_limit != 0 and iteration >= iter_limit: break