forked from mrq/DL-Art-School

Did anyone ask for k-means clustering?
This is so cool...

commit 5a8156026a
parent 197d19714f
@@ -20,6 +20,7 @@ from models.spinenet_arch import SpineNet
 # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
 # and the distance is computed across the channel dimension.
 from utils import util
+from utils.kmeans import kmeans, kmeans_predict
 from utils.options import dict_to_nonedict


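The comment in the hunk above describes structural_euc_dist, whose definition sits outside the changed lines. A minimal sketch of a distance with the stated properties (keep the [h, w] grid, reduce across channels), assuming a (B, C, H, W) signature:

```python
import torch

def structural_euc_dist(x, y):
    # x, y: (B, C, H, W). Reduce over the channel dim only, so every
    # spatial position keeps its own distance value: output is (B, H, W).
    return torch.sqrt(torch.square(x - y).sum(dim=1))
```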
@@ -51,13 +52,14 @@ def im_norm(x):
     return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5


-def get_image_folder_dataloader(batch_size, num_workers):
+def get_image_folder_dataloader(batch_size, num_workers, target_size=224):
     dataset_opt = dict_to_nonedict({
         'name': 'amalgam',
         'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
+        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
         #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
         'weights': [1],
-        'target_size': 224,
+        'target_size': target_size,
         'force_multiple': 32,
         'scale': 1
     })
@@ -119,7 +121,7 @@ def produce_latent_dict(model):
         id += batch_size
         if id > 1000:
             print("Saving checkpoint..")
-            torch.save((latents, paths), '../results.pth')
+            torch.save((latents, paths), '../results_instance_resnet.pth')
             id = 0


@@ -160,6 +162,30 @@ def find_similar_latents(model, compare_fn=structural_euc_dist):
     id = 0


+def build_kmeans():
+    latents, _ = torch.load('../results_instance_resnet.pth')
+    latents = torch.cat(latents, dim=0).squeeze().to('cuda')
+    cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'))
+    torch.save((cluster_ids_x, cluster_centers), '../k_means_instance_resnet.pth')
+
+
+def use_kmeans():
+    output = "../results/k_means_instance_resnet/"
+    _, centers = torch.load('../k_means_instance_resnet.pth')
+    batch_size = 8
+    num_workers = 0
+    dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224)
+    for i, batch in enumerate(tqdm(dataloader)):
+        hq = batch['hq'].to('cuda')
+        model(hq)
+        l = layer_hooked_value.clone().squeeze()
+        pred = kmeans_predict(l, centers, device=l.device)
+        for b in range(pred.shape[0]):
+            cat = str(pred[b].item())
+            os.makedirs(os.path.join(output, cat), exist_ok=True)
+            torchvision.utils.save_image(hq[b], os.path.join(output, cat, f'{i}.png'))
+
+
 if __name__ == '__main__':
     pretrained_path = '../../../experiments/resnet_byol_diffframe_115k.pth'
     model = resnet50(pretrained=False).to('cuda')
@@ -168,10 +194,12 @@ if __name__ == '__main__':
     for k, v in sd.items():
         if 'target_encoder.net.' in k:
             resnet_sd[k.replace('target_encoder.net.', '')] = v
-    model.load_state_dict(resnet_sd, strict=True)
+    model.load_state_dict(sd, strict=True)
     model.eval()
     register_hook(model, 'avgpool')

     with torch.no_grad():
         #find_similar_latents(model, structural_euc_dist)
-        produce_latent_dict(model)
+        #produce_latent_dict(model)
+        #build_kmeans()
+        use_kmeans()
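The use_kmeans() added above reads layer_hooked_value right after model(hq), relying on the register_hook helper invoked in __main__; both are defined elsewhere in this script and are not part of the diff. A minimal sketch of that forward-hook pattern, with assumed names matching the calls above:

```python
import torch
from torchvision.models import resnet50

layer_hooked_value = None

def register_hook(model, layer_name):
    # Stash the named submodule's forward output in a module-level cache.
    def _hook(module, inputs, output):
        global layer_hooked_value
        layer_hooked_value = output
    getattr(model, layer_name).register_forward_hook(_hook)

model = resnet50(pretrained=False).eval()
register_hook(model, 'avgpool')
with torch.no_grad():
    model(torch.randn(2, 3, 224, 224))
print(layer_hooked_value.shape)  # torch.Size([2, 2048, 1, 1])
```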
@@ -1,6 +1,7 @@
 import os
+import shutil

 import matplotlib.cm as cm
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -23,6 +24,7 @@ from models.spinenet_arch import SpineNet
 # and the distance is computed across the channel dimension.
 from scripts.byol.byol_spinenet_playground import find_similar_latents, create_latent_database
 from utils import util
+from utils.kmeans import kmeans, kmeans_predict
 from utils.options import dict_to_nonedict


@@ -54,13 +56,14 @@ def im_norm(x):
     return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5


-def get_image_folder_dataloader(batch_size, num_workers):
+def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
     dataset_opt = dict_to_nonedict({
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
+        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
+        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
         #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
         'weights': [1],
-        'target_size': 256,
+        'target_size': target_size,
         'force_multiple': 32,
         'scale': 1
     })
@@ -75,9 +78,17 @@ def produce_latent_dict(model):
     id = 0
     paths = []
     latents = []
+    prob = None
     for batch in tqdm(dataloader):
         hq = batch['hq'].to('cuda')
-        l = model(hq).cpu().split(1, dim=0)
+        l = model(hq)
+        b, c, h, w = l.shape
+        dim = b*h*w
+        l = l.permute(0,2,3,1).reshape(dim, c).cpu()
+        # extract a random set of 10 latents from each image
+        if prob is None:
+            prob = torch.full((dim,), 1/(dim))
+        l = l[prob.multinomial(num_samples=10, replacement=False)].split(1, dim=0)
         latents.extend(l)
         paths.extend(batch['HQ_path'])
         id += batch_size
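The added sampling block above flattens the feature map to a (B*H*W, C) matrix and keeps 10 rows drawn without replacement from a uniform multinomial; as written, the draw is over the whole batch rather than per image as the comment suggests. The step in isolation:

```python
import torch

l = torch.randn(4, 32, 16, 16)                # stand-in feature map (B, C, H, W)
b, c, h, w = l.shape
dim = b * h * w
flat = l.permute(0, 2, 3, 1).reshape(dim, c)  # one latent per spatial position
prob = torch.full((dim,), 1 / dim)            # uniform selection weights
idx = prob.multinomial(num_samples=10, replacement=False)
print(flat[idx].shape)                        # torch.Size([10, 32])
```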
@@ -87,6 +98,32 @@ def produce_latent_dict(model):
             id = 0


+def build_kmeans():
+    latents, _ = torch.load('../results.pth')
+    latents = torch.cat(latents, dim=0).to('cuda')
+    cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'))
+    torch.save((cluster_ids_x, cluster_centers), '../k_means.pth')
+
+
+def use_kmeans():
+    _, centers = torch.load('../k_means.pth')
+    batch_size = 8
+    num_workers = 0
+    dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=512)
+    colormap = cm.get_cmap('viridis', 8)
+    for i, batch in enumerate(tqdm(dataloader)):
+        hq = batch['hq'].to('cuda')
+        l = model(hq)
+        b, c, h, w = l.shape
+        dim = b*h*w
+        l = l.permute(0,2,3,1).reshape(dim,c)
+        pred = kmeans_predict(l, centers, device=l.device)
+        pred = pred.reshape(b,h,w)
+        img = torch.tensor(colormap(pred[:, :, :].detach().numpy()))
+        torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=8, mode="nearest"), f"{i}_categories.png")
+        torchvision.utils.save_image(hq, f"{i}_hq.png")
+
+
 if __name__ == '__main__':
     pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth'
     model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
@@ -101,4 +138,6 @@ if __name__ == '__main__':
     with torch.no_grad():
         #find_similar_latents(model, 0, 8, structural_euc_dist)
         #create_latent_database(model, batch_size=32)
-        produce_latent_dict(model)
+        #produce_latent_dict(model)
+        #build_kmeans()
+        use_kmeans()
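This file's use_kmeans() renders the per-position cluster ids as an image: each id indexes an 8-entry viridis colormap, and the resulting RGBA map is nearest-upsampled 8x to roughly undo the encoder's downsampling. The colormap step in isolation (integer inputs index the discrete map directly):

```python
import matplotlib.cm as cm
import numpy as np
import torch

colormap = cm.get_cmap('viridis', 8)              # discrete 8-color viridis
pred = np.random.randint(0, 8, size=(1, 16, 16))  # stand-in cluster ids (B, H, W)
img = torch.tensor(colormap(pred))                # (B, H, W, 4) RGBA floats in [0, 1]
img = img.permute(0, 3, 1, 2)                     # NCHW, ready for interpolate/save_image
print(img.shape)                                  # torch.Size([1, 4, 16, 16])
```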
@@ -20,8 +20,8 @@ def main():

     opt['dest'] = 'file'
     opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new']
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'
-    opt['imgsize'] = 256
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_384_full'
+    opt['imgsize'] = 384
     #opt['bottom_crop'] = 120

     save_folder = opt['save_folder']
codes/utils/kmeans.py (new file, 184 lines)
@@ -0,0 +1,184 @@
+# From: https://github.com/subhadarship/kmeans_pytorch
+# License: https://github.com/subhadarship/kmeans_pytorch/blob/master/LICENSE
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+
+# ToDo: a cluster can't be chosen if two points are too close to each other; that's where the NaNs come from.
+
+
+def initialize(X, num_clusters):
+    """
+    initialize cluster centers
+    :param X: (torch.tensor) matrix
+    :param num_clusters: (int) number of clusters
+    :return: (np.array) initial state
+    """
+    num_samples = len(X)
+    indices = np.random.choice(num_samples, num_clusters, replace=False)
+    initial_state = X[indices]
+    return initial_state
+
+
+def kmeans(
+        X,
+        num_clusters,
+        distance='euclidean',
+        cluster_centers=[],
+        tol=1e-4,
+        tqdm_flag=True,
+        iter_limit=0,
+        device=torch.device('cpu')
+):
+    """
+    perform kmeans
+    :param X: (torch.tensor) matrix
+    :param num_clusters: (int) number of clusters
+    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
+    :param tol: (float) threshold [default: 0.0001]
+    :param device: (torch.device) device [default: cpu]
+    :param tqdm_flag: allows turning logs on and off
+    :param iter_limit: hard limit for max number of iterations
+    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
+    """
+    print(f'running k-means on {device}..')
+
+    if distance == 'euclidean':
+        pairwise_distance_function = pairwise_distance
+    elif distance == 'cosine':
+        pairwise_distance_function = pairwise_cosine
+    else:
+        raise NotImplementedError
+
+    # convert to float
+    X = X.float()
+
+    # transfer to device
+    X = X.to(device)
+
+    # initialize
+    if type(cluster_centers) == list:  # ToDo: make this less annoyingly weird
+        initial_state = initialize(X, num_clusters)
+    else:
+        print('resuming')
+        # find data point closest to the initial cluster center
+        initial_state = cluster_centers
+        dis = pairwise_distance_function(X, initial_state)
+        choice_points = torch.argmin(dis, dim=0)
+        initial_state = X[choice_points]
+        initial_state = initial_state.to(device)
+
+    iteration = 0
+    if tqdm_flag:
+        tqdm_meter = tqdm(desc='[running kmeans]')
+    while True:
+
+        dis = pairwise_distance_function(X, initial_state)
+
+        choice_cluster = torch.argmin(dis, dim=1)
+
+        initial_state_pre = initial_state.clone()
+
+        for index in range(num_clusters):
+            selected = torch.nonzero(choice_cluster == index).squeeze().to(device)
+
+            selected = torch.index_select(X, 0, selected)
+
+            initial_state[index] = selected.mean(dim=0)
+
+        center_shift = torch.sum(
+            torch.sqrt(
+                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
+            ))
+
+        # increment iteration
+        iteration = iteration + 1
+
+        # update tqdm meter
+        if tqdm_flag:
+            tqdm_meter.set_postfix(
+                iteration=f'{iteration}',
+                center_shift=f'{center_shift ** 2:0.6f}',
+                tol=f'{tol:0.6f}'
+            )
+            tqdm_meter.update()
+        if center_shift ** 2 < tol:
+            break
+        if iter_limit != 0 and iteration >= iter_limit:
+            break
+
+    return choice_cluster.cpu(), initial_state.cpu()
+
+
+def kmeans_predict(
+        X,
+        cluster_centers,
+        distance='euclidean',
+        device=torch.device('cpu')
+):
+    """
+    predict using cluster centers
+    :param X: (torch.tensor) matrix
+    :param cluster_centers: (torch.tensor) cluster centers
+    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
+    :param device: (torch.device) device [default: 'cpu']
+    :return: (torch.tensor) cluster ids
+    """
+    print(f'predicting on {device}..')
+
+    if distance == 'euclidean':
+        pairwise_distance_function = pairwise_distance
+    elif distance == 'cosine':
+        pairwise_distance_function = pairwise_cosine
+    else:
+        raise NotImplementedError
+
+    # convert to float
+    X = X.float()
+
+    # transfer to device
+    X = X.to(device)
+
+    dis = pairwise_distance_function(X, cluster_centers)
+    choice_cluster = torch.argmin(dis, dim=1)
+
+    return choice_cluster.cpu()
+
+
+def pairwise_distance(data1, data2, device=torch.device('cpu')):
+    # transfer to device
+    data1, data2 = data1.to(device), data2.to(device)
+
+    # N*1*M
+    A = data1.unsqueeze(dim=1)
+
+    # 1*N*M
+    B = data2.unsqueeze(dim=0)
+
+    dis = (A - B) ** 2.0
+    # return N*N matrix for pairwise distance
+    dis = dis.sum(dim=-1).squeeze()
+    return dis
+
+
+def pairwise_cosine(data1, data2, device=torch.device('cpu')):
+    # transfer to device
+    data1, data2 = data1.to(device), data2.to(device)
+
+    # N*1*M
+    A = data1.unsqueeze(dim=1)
+
+    # 1*N*M
+    B = data2.unsqueeze(dim=0)
+
+    # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
+    A_normalized = A / A.norm(dim=-1, keepdim=True)
+    B_normalized = B / B.norm(dim=-1, keepdim=True)
+
+    cosine = A_normalized * B_normalized
+
+    # return N*N matrix for pairwise distance
+    cosine_dis = 1 - cosine.sum(dim=-1).squeeze()
+    return cosine_dis
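A small smoke test of the vendored utility on synthetic data; it assumes the working directory is codes/ so that utils.kmeans resolves:

```python
import torch
from utils.kmeans import kmeans, kmeans_predict

x = torch.randn(1000, 64)                        # 1000 latents, 64 dims
ids, centers = kmeans(x, num_clusters=8, distance='euclidean',
                      device=torch.device('cpu'))
new_ids = kmeans_predict(torch.randn(10, 64), centers,
                         device=torch.device('cpu'))
print(ids.shape, centers.shape, new_ids.shape)   # (1000,) (8, 64) (10,)
```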