# From: https://github.com/subhadarship/kmeans_pytorch
# License: https://github.com/subhadarship/kmeans_pytorch/blob/master/LICENSE

import random

import numpy as np
import torch

try:
    from tqdm import tqdm
except ImportError:
    # tqdm is only needed for the progress bar; clustering works without it.
    tqdm = None

# NOTE: random initialization samples distinct *indices*, not distinct
# *values*. Duplicate (or resumed, snapped-together) centers leave a cluster
# without members; kmeans() keeps the previous center in that case instead of
# averaging zero rows, which is where the NaN centers used to come from.


def initialize(X, num_clusters):
    """
    initialize cluster centers by sampling rows of X without replacement
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :return: the sampled rows of X (one row per cluster)
    :raises ValueError: (from numpy) if num_clusters exceeds the number of rows
    """
    num_samples = len(X)
    indices = np.random.choice(num_samples, num_clusters, replace=False)
    initial_state = X[indices]
    return initial_state


def _resolve_distance(distance):
    """Map a distance name to the matching pairwise distance function."""
    if distance == 'euclidean':
        return pairwise_distance
    if distance == 'cosine':
        return pairwise_cosine
    raise NotImplementedError(f'unsupported distance: {distance!r}')


def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=None,
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        gravity_limit_per_iter=None,
        device=torch.device('cpu')
):
    """
    perform kmeans
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param cluster_centers: centers to resume from; None or any list (the
        legacy "not provided" marker) starts from a random initialization
    :param tol: (float) threshold on the squared center shift [default: 0.0001];
        a value <= 0 disables the check
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations (0 = no limit)
    :param gravity_limit_per_iter: (int or None) if set, at most this many
        members (a random contiguous window) are averaged per cluster update
    :param device: (torch.device) device [default: cpu]
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers (both on cpu)
    :raises NotImplementedError: if distance is not a supported option
    :raises ValueError: if both stopping criteria are disabled
    """
    print(f'running k-means on {device}..')

    pairwise_distance_function = _resolve_distance(distance)

    if tol <= 0 and iter_limit == 0:
        # Neither break condition below could ever fire.
        raise ValueError(
            'kmeans needs tol > 0 or iter_limit != 0, otherwise it never stops')

    # convert to float and transfer to device
    X = X.float().to(device)

    # initialize
    if cluster_centers is None or isinstance(cluster_centers, list):
        initial_state = initialize(X, num_clusters)
    else:
        print('resuming')
        # find data point closest to each supplied cluster center; the centers
        # are first brought to X's dtype/device so the distance can be computed
        resumed = torch.as_tensor(cluster_centers, dtype=X.dtype, device=device)
        dis = pairwise_distance_function(X, resumed)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]

    iteration = 0
    show_progress = tqdm_flag and tqdm is not None
    tqdm_meter = tqdm(desc='[running kmeans]') if show_progress else None

    try:
        while True:
            dis = pairwise_distance_function(X, initial_state)
            choice_cluster = torch.argmin(dis, dim=1)
            initial_state_pre = initial_state.clone()

            for index in range(num_clusters):
                # squeeze(1) keeps a 1-D index even for a single member
                members = torch.nonzero(choice_cluster == index).squeeze(1)
                selected = torch.index_select(X, 0, members)

                if selected.shape[0] == 0:
                    # Empty cluster: the mean of zero rows is NaN, so keep
                    # the previous center instead.
                    continue

                if gravity_limit_per_iter and len(selected) > gravity_limit_per_iter:
                    ch = random.randint(0, len(selected) - gravity_limit_per_iter)
                    selected = selected[ch:ch + gravity_limit_per_iter]

                initial_state[index] = selected.mean(dim=0)

            # total distance the centers moved in this iteration
            center_shift = torch.sum(
                torch.sqrt(
                    torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
                ))

            # increment iteration
            iteration = iteration + 1

            # update tqdm meter
            if tqdm_meter is not None:
                bins = torch.bincount(choice_cluster)
                tqdm_meter.set_postfix(
                    iteration=f'{iteration}',
                    center_shift=f'{center_shift ** 2}',
                    tol=f'{tol}',
                    bins=f'{bins}',
                )
                tqdm_meter.update()

            if tol > 0 and center_shift ** 2 < tol:
                break
            if iter_limit != 0 and iteration >= iter_limit:
                break
    finally:
        if tqdm_meter is not None:
            tqdm_meter.close()

    return choice_cluster.cpu(), initial_state.cpu()


def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean'
):
    """
    predict using cluster centers
    :param X: (torch.tensor) matrix
    :param cluster_centers: (torch.tensor) cluster centers; moved to X's
        device before the distances are computed
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :return: (torch.tensor) cluster ids, one per row of X
    :raises NotImplementedError: if distance is not a supported option
    """
    pairwise_distance_function = _resolve_distance(distance)

    # kmeans() returns its centers on the cpu, X may live elsewhere
    cluster_centers = cluster_centers.to(X.device)

    dis = pairwise_distance_function(X, cluster_centers)
    choice_cluster = torch.argmin(dis, dim=1)

    return choice_cluster


def pairwise_distance(data1, data2):
    """
    squared euclidean distance between every row of data1 and every row of data2
    :param data1: (torch.tensor) N*M matrix
    :param data2: (torch.tensor) K*M matrix
    :return: (torch.tensor) N*K matrix; both axes are kept even when N or K is 1
    """
    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*K*M
    B = data2.unsqueeze(dim=0)

    dis = (A - B) ** 2.0
    # reduce the feature axis only, so the result is always N*K
    return dis.sum(dim=-1)


def pairwise_cosine(data1, data2):
    """
    cosine distance (1 - cosine similarity) between every row of data1 and
    every row of data2
    :param data1: (torch.tensor) N*M matrix
    :param data2: (torch.tensor) K*M matrix
    :return: (torch.tensor) N*K matrix; both axes are kept even when N or K is 1
    """
    # guards the normalization against zero-length vectors (0 / 0 -> NaN)
    eps = 1e-8

    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*K*M
    B = data2.unsqueeze(dim=0)

    # normalize the points  | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
    A_normalized = A / A.norm(dim=-1, keepdim=True).clamp(min=eps)
    B_normalized = B / B.norm(dim=-1, keepdim=True).clamp(min=eps)

    cosine = A_normalized * B_normalized

    # reduce the feature axis only, so the result is always N*K
    cosine_dis = 1 - cosine.sum(dim=-1)
    return cosine_dis