From 5a8156026a9931736b97e9c1d2e4ded80b2b62a1 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 7 Jan 2021 22:37:41 -0700
Subject: [PATCH] Did anyone ask for k-means clustering?

This is so cool...
---
 codes/models/improve_rrdb/__init__.py         |   0
 codes/scripts/byol/byol_resnet_playground.py  |  38 +++-
 codes/scripts/byol/byol_uresnet_playground.py |  49 ++++-
 codes/scripts/extract_square_images.py        |   4 +-
 codes/utils/kmeans.py                         | 184 ++++++++++++++++++
 5 files changed, 263 insertions(+), 12 deletions(-)
 delete mode 100644 codes/models/improve_rrdb/__init__.py
 create mode 100644 codes/utils/kmeans.py

diff --git a/codes/models/improve_rrdb/__init__.py b/codes/models/improve_rrdb/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/codes/scripts/byol/byol_resnet_playground.py b/codes/scripts/byol/byol_resnet_playground.py
index a47b979c..7348fa26 100644
--- a/codes/scripts/byol/byol_resnet_playground.py
+++ b/codes/scripts/byol/byol_resnet_playground.py
@@ -20,6 +20,7 @@ from models.spinenet_arch import SpineNet
 # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
 # and the distance is computed across the channel dimension.
 from utils import util
+from utils.kmeans import kmeans, kmeans_predict
 from utils.options import dict_to_nonedict
 
 
@@ -51,13 +52,14 @@ def im_norm(x):
     return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
 
 
-def get_image_folder_dataloader(batch_size, num_workers):
+def get_image_folder_dataloader(batch_size, num_workers, target_size=224):
     dataset_opt = dict_to_nonedict({
         'name': 'amalgam',
         'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
+        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
         #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
         'weights': [1],
-        'target_size': 224,
+        'target_size': target_size,
         'force_multiple': 32,
         'scale': 1
     })
@@ -119,7 +121,7 @@ def produce_latent_dict(model):
         id += batch_size
         if id > 1000:
             print("Saving checkpoint..")
-            torch.save((latents, paths), '../results.pth')
+            torch.save((latents, paths), '../results_instance_resnet.pth')
             id = 0
 
 
@@ -160,6 +162,30 @@ def find_similar_latents(model, compare_fn=structural_euc_dist):
             id = 0
 
 
+def build_kmeans():
+    latents, _ = torch.load('../results_instance_resnet.pth')
+    latents = torch.cat(latents, dim=0).squeeze().to('cuda')
+    cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'))
+    torch.save((cluster_ids_x, cluster_centers), '../k_means_instance_resnet.pth')
+
+
+def use_kmeans():
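+    # Run each dataset image through the hooked ResNet, predict its nearest k-means center,
+    # and save the image into a per-cluster output folder.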
+    output = "../results/k_means_instance_resnet/"
+    _, centers = torch.load('../k_means_instance_resnet.pth')
+    batch_size = 8
+    num_workers = 0
+    dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224)
+    for i, batch in enumerate(tqdm(dataloader)):
+        hq = batch['hq'].to('cuda')
+        model(hq)
+        l = layer_hooked_value.clone().squeeze()
+        pred = kmeans_predict(l, centers, device=l.device)
+        for b in range(pred.shape[0]):
+            cat = str(pred[b].item())
+            os.makedirs(os.path.join(output, cat), exist_ok=True)
+            torchvision.utils.save_image(hq[b], os.path.join(output, cat, f'{i}.png'))
+
+
 if __name__ == '__main__':
     pretrained_path = '../../../experiments/resnet_byol_diffframe_115k.pth'
     model = resnet50(pretrained=False).to('cuda')
@@ -168,10 +194,12 @@ if __name__ == '__main__':
     for k, v in sd.items():
         if 'target_encoder.net.' in k:
             resnet_sd[k.replace('target_encoder.net.', '')] = v
-    model.load_state_dict(resnet_sd, strict=True)
+    model.load_state_dict(sd, strict=True)
     model.eval()
 
     register_hook(model, 'avgpool')
     with torch.no_grad():
         #find_similar_latents(model, structural_euc_dist)
-        produce_latent_dict(model)
+        #produce_latent_dict(model)
+        #build_kmeans()
+        use_kmeans()
diff --git a/codes/scripts/byol/byol_uresnet_playground.py b/codes/scripts/byol/byol_uresnet_playground.py
index 7296dc32..bec5da68 100644
--- a/codes/scripts/byol/byol_uresnet_playground.py
+++ b/codes/scripts/byol/byol_uresnet_playground.py
@@ -1,6 +1,7 @@
 import os
 import shutil
 
+import matplotlib.cm as cm
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -23,6 +24,7 @@ from models.spinenet_arch import SpineNet
 # and the distance is computed across the channel dimension.
 from scripts.byol.byol_spinenet_playground import find_similar_latents, create_latent_database
 from utils import util
+from utils.kmeans import kmeans, kmeans_predict
 from utils.options import dict_to_nonedict
 
 
@@ -54,13 +56,14 @@ def im_norm(x):
     return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
 
 
-def get_image_folder_dataloader(batch_size, num_workers):
+def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
     dataset_opt = dict_to_nonedict({
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
+        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
+        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
         #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
         'weights': [1],
-        'target_size': 256,
+        'target_size': target_size,
         'force_multiple': 32,
         'scale': 1
     })
@@ -75,9 +78,17 @@ def produce_latent_dict(model):
     id = 0
     paths = []
     latents = []
+    prob = None
     for batch in tqdm(dataloader):
         hq = batch['hq'].to('cuda')
-        l = model(hq).cpu().split(1, dim=0)
+        l = model(hq)
+        b, c, h, w = l.shape
+        dim = b*h*w
+        l = l.permute(0,2,3,1).reshape(dim, c).cpu()
+        # extract a random set of 10 latents from each image
+        if prob is None:
+            prob = torch.full((dim,), 1/(dim))
+        l = l[prob.multinomial(num_samples=10, replacement=False)].split(1, dim=0)
         latents.extend(l)
         paths.extend(batch['HQ_path'])
         id += batch_size
@@ -87,6 +98,32 @@
             id = 0
 
 
+def build_kmeans():
+    latents, _ = torch.load('../results.pth')
+    latents = torch.cat(latents, dim=0).to('cuda')
+    cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'))
+    torch.save((cluster_ids_x, cluster_centers), '../k_means.pth')
+
+
+def use_kmeans():
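+    # Predict a cluster id for every spatial position of the latent map and save the per-pixel
+    # assignments as a viridis-colored image alongside the original input.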
+    _, centers = torch.load('../k_means.pth')
+    batch_size = 8
+    num_workers = 0
+    dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=512)
+    colormap = cm.get_cmap('viridis', 8)
+    for i, batch in enumerate(tqdm(dataloader)):
+        hq = batch['hq'].to('cuda')
+        l = model(hq)
+        b, c, h, w = l.shape
+        dim = b*h*w
+        l = l.permute(0,2,3,1).reshape(dim,c)
+        pred = kmeans_predict(l, centers, device=l.device)
+        pred = pred.reshape(b,h,w)
+        img = torch.tensor(colormap(pred[:, :, :].detach().numpy()))
+        torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=8, mode="nearest"), f"{i}_categories.png")
+        torchvision.utils.save_image(hq, f"{i}_hq.png")
+
+
 if __name__ == '__main__':
     pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth'
     model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
@@ -101,4 +138,6 @@ if __name__ == '__main__':
     with torch.no_grad():
         #find_similar_latents(model, 0, 8, structural_euc_dist)
         #create_latent_database(model, batch_size=32)
-        produce_latent_dict(model)
+        #produce_latent_dict(model)
+        #build_kmeans()
+        use_kmeans()
diff --git a/codes/scripts/extract_square_images.py b/codes/scripts/extract_square_images.py
index fff68aee..e5e46501 100644
--- a/codes/scripts/extract_square_images.py
+++ b/codes/scripts/extract_square_images.py
@@ -20,8 +20,8 @@ def main():
     opt['dest'] = 'file'
     opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new']
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'
-    opt['imgsize'] = 256
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_384_full'
+    opt['imgsize'] = 384
     #opt['bottom_crop'] = 120
 
     save_folder = opt['save_folder']
diff --git a/codes/utils/kmeans.py b/codes/utils/kmeans.py
new file mode 100644
index 00000000..f5692bb2
--- /dev/null
+++ b/codes/utils/kmeans.py
@@ -0,0 +1,184 @@
+# From: https://github.com/subhadarship/kmeans_pytorch
+# License: https://github.com/subhadarship/kmeans_pytorch/blob/master/LICENSE
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+
+# ToDo: Can't choose a cluster if two points are too close to each other, that's where the nan come from
+
+
+def initialize(X, num_clusters):
+    """
+    initialize cluster centers
+    :param X: (torch.tensor) matrix
+    :param num_clusters: (int) number of clusters
+    :return: (np.array) initial state
+    """
+    num_samples = len(X)
+    indices = np.random.choice(num_samples, num_clusters, replace=False)
+    initial_state = X[indices]
+    return initial_state
+
+
+def kmeans(
+        X,
+        num_clusters,
+        distance='euclidean',
+        cluster_centers=[],
+        tol=1e-4,
+        tqdm_flag=True,
+        iter_limit=0,
+        device=torch.device('cpu')
+):
+    """
+    perform kmeans
+    :param X: (torch.tensor) matrix
+    :param num_clusters: (int) number of clusters
+    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
+    :param tol: (float) threshold [default: 0.0001]
+    :param device: (torch.device) device [default: cpu]
+    :param tqdm_flag: Allows to turn logs on and off
+    :param iter_limit: hard limit for max number of iterations
+    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
+    """
+    print(f'running k-means on {device}..')
+
+    if distance == 'euclidean':
+        pairwise_distance_function = pairwise_distance
+    elif distance == 'cosine':
+        pairwise_distance_function = pairwise_cosine
+    else:
+        raise NotImplementedError
+
+    # convert to float
+    X = X.float()
+
+    # transfer to device
+    X = X.to(device)
+
+    # initialize
+    if type(cluster_centers) == list:  # ToDo: make this less annoyingly weird
+        initial_state = initialize(X, num_clusters)
+    else:
+        print('resuming')
+        # find data point closest to the initial cluster center
+        initial_state = cluster_centers
+        dis = pairwise_distance_function(X, initial_state)
+        choice_points = torch.argmin(dis, dim=0)
+        initial_state = X[choice_points]
+        initial_state = initial_state.to(device)
+
+    iteration = 0
+    if tqdm_flag:
+        tqdm_meter = tqdm(desc='[running kmeans]')
+    while True:
+
+        dis = pairwise_distance_function(X, initial_state)
+
+        choice_cluster = torch.argmin(dis, dim=1)
+
+        initial_state_pre = initial_state.clone()
+
+        for index in range(num_clusters):
+            selected = torch.nonzero(choice_cluster == index).squeeze().to(device)
+
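+            # Recompute this cluster's center as the mean of the points currently assigned to it.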
+            selected = torch.index_select(X, 0, selected)
+
+            initial_state[index] = selected.mean(dim=0)
+
+        center_shift = torch.sum(
+            torch.sqrt(
+                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
+            ))
+
+        # increment iteration
+        iteration = iteration + 1
+
+        # update tqdm meter
+        if tqdm_flag:
+            tqdm_meter.set_postfix(
+                iteration=f'{iteration}',
+                center_shift=f'{center_shift ** 2:0.6f}',
+                tol=f'{tol:0.6f}'
+            )
+            tqdm_meter.update()
+        if center_shift ** 2 < tol:
+            break
+        if iter_limit != 0 and iteration >= iter_limit:
+            break
+
+    return choice_cluster.cpu(), initial_state.cpu()
+
+
+def kmeans_predict(
+        X,
+        cluster_centers,
+        distance='euclidean',
+        device=torch.device('cpu')
+):
+    """
+    predict using cluster centers
+    :param X: (torch.tensor) matrix
+    :param cluster_centers: (torch.tensor) cluster centers
+    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
+    :param device: (torch.device) device [default: 'cpu']
+    :return: (torch.tensor) cluster ids
+    """
+    print(f'predicting on {device}..')
+
+    if distance == 'euclidean':
+        pairwise_distance_function = pairwise_distance
+    elif distance == 'cosine':
+        pairwise_distance_function = pairwise_cosine
+    else:
+        raise NotImplementedError
+
+    # convert to float
+    X = X.float()
+
+    # transfer to device
+    X = X.to(device)
+
+    dis = pairwise_distance_function(X, cluster_centers)
+    choice_cluster = torch.argmin(dis, dim=1)
+
+    return choice_cluster.cpu()
+
+
+def pairwise_distance(data1, data2, device=torch.device('cpu')):
+    # transfer to device
+    data1, data2 = data1.to(device), data2.to(device)
+
+    # N*1*M
+    A = data1.unsqueeze(dim=1)
+
+    # 1*N*M
+    B = data2.unsqueeze(dim=0)
+
+    dis = (A - B) ** 2.0
+    # return N*N matrix for pairwise distance
+    dis = dis.sum(dim=-1).squeeze()
+    return dis
+
+
+def pairwise_cosine(data1, data2, device=torch.device('cpu')):
+    # transfer to device
+    data1, data2 = data1.to(device), data2.to(device)
+
+    # N*1*M
+    A = data1.unsqueeze(dim=1)
+
+    # 1*N*M
+    B = data2.unsqueeze(dim=0)
+
+    # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
+    A_normalized = A / A.norm(dim=-1, keepdim=True)
+    B_normalized = B / B.norm(dim=-1, keepdim=True)
+
+    cosine = A_normalized * B_normalized
+
+    # return N*N matrix for pairwise distance
+    cosine_dis = 1 - cosine.sum(dim=-1).squeeze()
+    return cosine_dis