Did anyone ask for k-means clustering?

This is so cool...
This commit is contained in:
James Betker 2021-01-07 22:37:41 -07:00
parent 197d19714f
commit 5a8156026a
5 changed files with 263 additions and 12 deletions

View File

@ -20,6 +20,7 @@ from models.spinenet_arch import SpineNet
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
from utils import util
from utils.kmeans import kmeans, kmeans_predict
from utils.options import dict_to_nonedict
@ -51,13 +52,14 @@ def im_norm(x):
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
def get_image_folder_dataloader(batch_size, num_workers):
def get_image_folder_dataloader(batch_size, num_workers, target_size=224):
dataset_opt = dict_to_nonedict({
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'weights': [1],
'target_size': 224,
'target_size': target_size,
'force_multiple': 32,
'scale': 1
})
@ -119,7 +121,7 @@ def produce_latent_dict(model):
id += batch_size
if id > 1000:
print("Saving checkpoint..")
torch.save((latents, paths), '../results.pth')
torch.save((latents, paths), '../results_instance_resnet.pth')
id = 0
@ -160,6 +162,30 @@ def find_similar_latents(model, compare_fn=structural_euc_dist):
id = 0
def build_kmeans():
latents, _ = torch.load('../results_instance_resnet.pth')
latents = torch.cat(latents, dim=0).squeeze().to('cuda')
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'))
torch.save((cluster_ids_x, cluster_centers), '../k_means_instance_resnet.pth')
def use_kmeans():
output = "../results/k_means_instance_resnet/"
_, centers = torch.load('../k_means_instance_resnet.pth')
batch_size = 8
num_workers = 0
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224)
for i, batch in enumerate(tqdm(dataloader)):
hq = batch['hq'].to('cuda')
model(hq)
l = layer_hooked_value.clone().squeeze()
pred = kmeans_predict(l, centers, device=l.device)
for b in range(pred.shape[0]):
cat = str(pred[b].item())
os.makedirs(os.path.join(output, cat), exist_ok=True)
torchvision.utils.save_image(hq[b], os.path.join(output, cat, f'{i}.png'))
if __name__ == '__main__':
pretrained_path = '../../../experiments/resnet_byol_diffframe_115k.pth'
model = resnet50(pretrained=False).to('cuda')
@ -168,10 +194,12 @@ if __name__ == '__main__':
for k, v in sd.items():
if 'target_encoder.net.' in k:
resnet_sd[k.replace('target_encoder.net.', '')] = v
model.load_state_dict(resnet_sd, strict=True)
model.load_state_dict(sd, strict=True)
model.eval()
register_hook(model, 'avgpool')
with torch.no_grad():
#find_similar_latents(model, structural_euc_dist)
produce_latent_dict(model)
#produce_latent_dict(model)
#build_kmeans()
use_kmeans()

View File

@ -1,6 +1,7 @@
import os
import shutil
import matplotlib.cm as cm
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -23,6 +24,7 @@ from models.spinenet_arch import SpineNet
# and the distance is computed across the channel dimension.
from scripts.byol.byol_spinenet_playground import find_similar_latents, create_latent_database
from utils import util
from utils.kmeans import kmeans, kmeans_predict
from utils.options import dict_to_nonedict
@ -54,13 +56,14 @@ def im_norm(x):
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
def get_image_folder_dataloader(batch_size, num_workers):
def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
dataset_opt = dict_to_nonedict({
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'weights': [1],
'target_size': 256,
'target_size': target_size,
'force_multiple': 32,
'scale': 1
})
@ -75,9 +78,17 @@ def produce_latent_dict(model):
id = 0
paths = []
latents = []
prob = None
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda')
l = model(hq).cpu().split(1, dim=0)
l = model(hq)
b, c, h, w = l.shape
dim = b*h*w
l = l.permute(0,2,3,1).reshape(dim, c).cpu()
# extract a random set of 10 latents from each image
if prob is None:
prob = torch.full((dim,), 1/(dim))
l = l[prob.multinomial(num_samples=10, replacement=False)].split(1, dim=0)
latents.extend(l)
paths.extend(batch['HQ_path'])
id += batch_size
@ -87,6 +98,32 @@ def produce_latent_dict(model):
id = 0
def build_kmeans():
latents, _ = torch.load('../results.pth')
latents = torch.cat(latents, dim=0).to('cuda')
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'))
torch.save((cluster_ids_x, cluster_centers), '../k_means.pth')
def use_kmeans():
_, centers = torch.load('../k_means.pth')
batch_size = 8
num_workers = 0
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=512)
colormap = cm.get_cmap('viridis', 8)
for i, batch in enumerate(tqdm(dataloader)):
hq = batch['hq'].to('cuda')
l = model(hq)
b, c, h, w = l.shape
dim = b*h*w
l = l.permute(0,2,3,1).reshape(dim,c)
pred = kmeans_predict(l, centers, device=l.device)
pred = pred.reshape(b,h,w)
img = torch.tensor(colormap(pred[:, :, :].detach().numpy()))
torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=8, mode="nearest"), f"{i}_categories.png")
torchvision.utils.save_image(hq, f"{i}_hq.png")
if __name__ == '__main__':
pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth'
model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
@ -101,4 +138,6 @@ if __name__ == '__main__':
with torch.no_grad():
#find_similar_latents(model, 0, 8, structural_euc_dist)
#create_latent_database(model, batch_size=32)
produce_latent_dict(model)
#produce_latent_dict(model)
#build_kmeans()
use_kmeans()

View File

@ -20,8 +20,8 @@ def main():
opt['dest'] = 'file'
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new']
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'
opt['imgsize'] = 256
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_384_full'
opt['imgsize'] = 384
#opt['bottom_crop'] = 120
save_folder = opt['save_folder']

184
codes/utils/kmeans.py Normal file
View File

@ -0,0 +1,184 @@
# From: https://github.com/subhadarship/kmeans_pytorch
# License: https://github.com/subhadarship/kmeans_pytorch/blob/master/LICENSE
import numpy as np
import torch
from tqdm import tqdm
# ToDo: Can't choose a cluster if two points are too close to each other, that's where the nan come from
def initialize(X, num_clusters):
    """
    initialize cluster centers by picking `num_clusters` distinct rows of X
    uniformly at random (without replacement)

    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :return: (torch.tensor) initial state
    """
    chosen = np.random.choice(len(X), num_clusters, replace=False)
    return X[chosen]
def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=[],
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        device=torch.device('cpu')
):
    """
    perform kmeans (Lloyd's algorithm)

    :param X: (torch.tensor) [N, D] matrix of samples
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param cluster_centers: (torch.tensor) centers to resume from; a list (the
        default) means "initialize from random data points"
    :param tol: (float) convergence threshold on squared center shift [default: 0.0001]
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations (0 = unlimited)
    :param device: (torch.device) device [default: cpu]
    :return: (torch.tensor, torch.tensor) cluster ids (on cpu), cluster centers (on cpu)
    """
    print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = pairwise_distance
    elif distance == 'cosine':
        pairwise_distance_function = pairwise_cosine
    else:
        raise NotImplementedError

    # convert to float
    X = X.float()
    # transfer to device
    X = X.to(device)

    # initialize
    if type(cluster_centers) == list:  # ToDo: make this less annoyingly weird
        initial_state = initialize(X, num_clusters)
    else:
        print('resuming')
        # find data point closest to the initial cluster center
        initial_state = cluster_centers
        # forward `device` so the distance is computed where the data lives
        dis = pairwise_distance_function(X, initial_state, device=device)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
        initial_state = initial_state.to(device)

    iteration = 0
    if tqdm_flag:
        tqdm_meter = tqdm(desc='[running kmeans]')
    while True:
        dis = pairwise_distance_function(X, initial_state, device=device)

        choice_cluster = torch.argmin(dis, dim=1)

        initial_state_pre = initial_state.clone()

        for index in range(num_clusters):
            # reshape(-1) (not squeeze()) keeps a 1-D index tensor even when the
            # cluster holds exactly one point; squeeze() would make it 0-dim and
            # break torch.index_select.
            selected = torch.nonzero(choice_cluster == index).reshape(-1).to(device)

            if selected.numel() == 0:
                # Empty cluster: keep the previous center instead of taking the
                # mean of nothing, which would poison the centers with NaNs
                # (the NaN source noted in the module ToDo).
                continue

            selected = torch.index_select(X, 0, selected)
            initial_state[index] = selected.mean(dim=0)

        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))

        # increment iteration
        iteration = iteration + 1

        # update tqdm meter
        if tqdm_flag:
            tqdm_meter.set_postfix(
                iteration=f'{iteration}',
                center_shift=f'{center_shift ** 2:0.6f}',
                tol=f'{tol:0.6f}'
            )
            tqdm_meter.update()
        if center_shift ** 2 < tol:
            break
        if iter_limit != 0 and iteration >= iter_limit:
            break

    if tqdm_flag:
        tqdm_meter.close()

    return choice_cluster.cpu(), initial_state.cpu()
def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu')
):
    """
    predict the nearest cluster for each row of X using fixed cluster centers

    :param X: (torch.tensor) matrix
    :param cluster_centers: (torch.tensor) cluster centers
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device [default: 'cpu']
    :return: (torch.tensor) cluster ids (on cpu)
    """
    print(f'predicting on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = pairwise_distance
    elif distance == 'cosine':
        pairwise_distance_function = pairwise_cosine
    else:
        raise NotImplementedError

    # convert to float
    X = X.float()

    # transfer to device
    X = X.to(device)

    # Forward `device` so the distance is computed on the requested device;
    # the original dropped it and always fell back to the pairwise fn's cpu default.
    dis = pairwise_distance_function(X, cluster_centers, device=device)
    choice_cluster = torch.argmin(dis, dim=1)

    return choice_cluster.cpu()
def pairwise_distance(data1, data2, device=torch.device('cpu')):
    """
    squared euclidean distance between every row of data1 and every row of data2

    :param data1: (torch.tensor) N*D matrix
    :param data2: (torch.tensor) M*D matrix
    :param device: (torch.device) where to compute [default: cpu]
    :return: (torch.tensor) N*M distance matrix (singleton dims squeezed out)
    """
    data1 = data1.to(device)
    data2 = data2.to(device)

    # broadcast N*1*D against 1*M*D to get all row pairs at once
    diff = data1.unsqueeze(dim=1) - data2.unsqueeze(dim=0)

    # sum squared differences over the feature axis; squeeze matches the
    # original output shape
    return (diff ** 2.0).sum(dim=-1).squeeze()
def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    """
    cosine distance (1 - cosine similarity) between every row of data1 and
    every row of data2

    :param data1: (torch.tensor) N*D matrix
    :param data2: (torch.tensor) M*D matrix
    :param device: (torch.device) where to compute [default: cpu]
    :return: (torch.tensor) N*M distance matrix (singleton dims squeezed out)
    """
    data1 = data1.to(device)
    data2 = data2.to(device)

    a = data1.unsqueeze(dim=1)  # N*1*D
    b = data2.unsqueeze(dim=0)  # 1*M*D

    # unit-normalize each row along the feature axis
    # e.g. [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
    a = a / a.norm(dim=-1, keepdim=True)
    b = b / b.norm(dim=-1, keepdim=True)

    # dot product of unit vectors is cosine similarity; distance is 1 - sim
    return 1 - (a * b).sum(dim=-1).squeeze()