forked from mrq/DL-Art-School
clustering script and injector
This commit is contained in:
parent
ee3b426dae
commit
861aa0e139
91
codes/scripts/gen_kmeans_clusters.py
Normal file
91
codes/scripts/gen_kmeans_clusters.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from random import shuffle
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from pykeops.torch import LazyTensor
|
||||
|
||||
use_cuda = False
|
||||
dtype = torch.float32
|
||||
device_id = 'cpu'
|
||||
|
||||
|
||||
def load_vectors():
|
||||
""" Will need to be modified per-data type you are loading. """
|
||||
all_files = torch.load('/y/separated/large_mel_cheaters_linux.pth')
|
||||
os.makedirs('/y/separated/randomly_sampled_cheaters', exist_ok=True)
|
||||
vecs = []
|
||||
print("Gathering vectors..")
|
||||
j = 0
|
||||
for f in tqdm(all_files):
|
||||
vs=torch.tensor(np.load(f)['arr_0'])
|
||||
for k in range(4):
|
||||
vecs.append(vs[0,:,random.randint(0,vs.shape[-1]-1)])
|
||||
if len(vecs) >= 1000000:
|
||||
vecs = torch.stack(vecs, dim=0)
|
||||
torch.save(vecs, f'/y/separated/randomly_sampled_cheaters/{j}.pth')
|
||||
j += 1
|
||||
vecs = []
|
||||
vecs = [torch.stack(vecs, dim=0)]
|
||||
for i in range(j):
|
||||
vecs.append(torch.load(f'/y/separated/randomly_sampled_cheaters/{i}.pth'))
|
||||
vecs = torch.cat(vecs, dim=0)
|
||||
torch.save(vecs, '/y/separated/randomly_sampled_cheaters/combined.pth')
|
||||
|
||||
def k_means(x, K, Niter=10, verbose=True):
|
||||
"""Implements Lloyd's algorithm for the Euclidean metric.
|
||||
Thanks to https://www.kernel-operations.io/keops/_auto_tutorials/kmeans/plot_kmeans_torch.html
|
||||
"""
|
||||
|
||||
start = time.time()
|
||||
N, D = x.shape # Number of samples, dimension of the ambient space
|
||||
|
||||
c = x[:K, :].clone() # Simplistic initialization for the centroids
|
||||
|
||||
x_i = LazyTensor(x.view(N, 1, D)) # (N, 1, D) samples
|
||||
c_j = LazyTensor(c.view(1, K, D)) # (1, K, D) centroids
|
||||
|
||||
# K-means loop:
|
||||
# - x is the (N, D) point cloud,
|
||||
# - cl is the (N,) vector of class labels
|
||||
# - c is the (K, D) cloud of cluster centroids
|
||||
for i in tqdm(range(Niter)):
|
||||
|
||||
# E step: assign points to the closest cluster -------------------------
|
||||
D_ij = ((x_i - c_j) ** 2).sum(-1) # (N, K) symbolic squared distances
|
||||
cl = D_ij.argmin(dim=1).long().view(-1) # Points -> Nearest cluster
|
||||
|
||||
# M step: update the centroids to the normalized cluster average: ------
|
||||
# Compute the sum of points per cluster:
|
||||
c.zero_()
|
||||
c.scatter_add_(0, cl[:, None].repeat(1, D), x)
|
||||
|
||||
# Divide by the number of points per cluster:
|
||||
Ncl = torch.bincount(cl, minlength=K).type_as(c).view(K, 1)
|
||||
c /= Ncl # in-place division to compute the average
|
||||
|
||||
if verbose: # Fancy display -----------------------------------------------
|
||||
if use_cuda:
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
print(
|
||||
f"K-means for the Euclidean metric with {N:,} points in dimension {D:,}, K = {K:,}:"
|
||||
)
|
||||
print(
|
||||
"Timing for {} iterations: {:.5f}s = {} x {:.5f}s\n".format(
|
||||
Niter, end - start, Niter, (end - start) / Niter
|
||||
)
|
||||
)
|
||||
|
||||
return cl, c
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#load_vectors()
|
||||
vecs = torch.load('/y/separated/randomly_sampled_cheaters/combined.pth')
|
||||
cl, c = k_means(vecs, 8192, 50)
|
||||
torch.save((cl, c), '/y/separated/randomly_sampled_cheaters/k_means_clusters.pth')
|
|
@ -408,5 +408,23 @@ class MusicCheaterLatentInjector(Injector):
|
|||
return {self.output: proj}
|
||||
|
||||
|
||||
class KmeansQuantizer(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
_, self.centroids = torch.load(opt['centroids'])
|
||||
k, b = self.centroids.shape
|
||||
self.centroids = self.centroids.reshape(1, k, b, 1)
|
||||
|
||||
def forward(self, state):
|
||||
with torch.no_grad():
|
||||
x = state[self.input]
|
||||
self.centroids = self.centroids.to(x.device)
|
||||
distances = ((self.centroids - x.unsqueeze(1))**2).sum(2)
|
||||
distances[distances.isnan()] = 9999999999
|
||||
labels = distances.argmin(1)
|
||||
return {self.output: labels}
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('hi')
|
||||
print('hi')
|
||||
|
|
Loading…
Reference in New Issue
Block a user