From 39a94c74b59ba0c699e0b4036f4d085b80471355 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Mon, 4 Jan 2021 20:11:29 -0700
Subject: [PATCH] Allow BYOL resnet playground to produce a latent dict

---
 codes/scripts/byol_resnet_playground.py |  25 ++-
 codes/scripts/tsne_torch.py             | 263 ++++++++++++++++++++++++
 2 files changed, 286 insertions(+), 2 deletions(-)

diff --git a/codes/scripts/byol_resnet_playground.py b/codes/scripts/byol_resnet_playground.py
index c21115e9..fe818ea7 100644
--- a/codes/scripts/byol_resnet_playground.py
+++ b/codes/scripts/byol_resnet_playground.py
@@ -103,10 +103,30 @@ def get_latent_for_img(model, img):
     return latent
 
 
+def produce_latent_dict(model):
+    batch_size = 32
+    num_workers = 4
+    dataloader = get_image_folder_dataloader(batch_size, num_workers)
+    id = 0
+    paths = []
+    latents = []
+    for batch in tqdm(dataloader):
+        hq = batch['hq'].to('cuda')
+        model(hq)
+        l = layer_hooked_value.cpu().split(1, dim=0)
+        latents.extend(l)
+        paths.extend(batch['HQ_path'])
+        id += batch_size
+        if id > 1000:
+            print("Saving checkpoint..")
+            torch.save((latents, paths), 'results.pth')
+            id = 0
+
+
 def find_similar_latents(model, compare_fn=structural_euc_dist):
     global layer_hooked_value
 
-    img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg'
+    img = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test\\80692045.jpg.jpg'
     #img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
     output_path = '../../results/byol_resnet_similars'
     os.makedirs(output_path, exist_ok=True)
@@ -153,4 +173,5 @@ if __name__ == '__main__':
     register_hook(model, 'avgpool')
 
     with torch.no_grad():
-        find_similar_latents(model, structural_euc_dist)
+        #find_similar_latents(model, structural_euc_dist)
+        produce_latent_dict(model)
diff --git a/codes/scripts/tsne_torch.py b/codes/scripts/tsne_torch.py
index e69de29b..9e08f7fa 100644
--- a/codes/scripts/tsne_torch.py
+++ b/codes/scripts/tsne_torch.py
@@ -0,0 +1,263 @@
+#
+#  tsne_torch.py
+#
+# Implementation of t-SNE in pytorch. The implementation was tested on pytorch
+# > 1.0, and it requires Numpy to read files. In order to plot the results,
+# a working installation of matplotlib is required.
+#
+#
+# The example can be run by executing: `python tsne_torch.py`
+#
+#
+#  Created by Xiao Li on 23-03-2020.
+#  Copyright (c) 2020. All rights reserved.
+from random import shuffle
+
+import numpy as np
+import matplotlib.pyplot as pyplot
+import argparse
+import torch
+from matplotlib.offsetbox import OffsetImage, AnnotationBbox
+from tqdm import tqdm
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--xfile", type=str, default="mnist2500_X.txt", help="file name of feature stored")
+parser.add_argument("--cuda", type=int, default=1, help="if use cuda accelarate")
+
+opt = parser.parse_args()
+print("get choice from args", opt)
+xfile = opt.xfile
+
+if opt.cuda:
+    print("set use cuda")
+    torch.set_default_tensor_type(torch.cuda.DoubleTensor)
+else:
+    torch.set_default_tensor_type(torch.DoubleTensor)
+
+
+def Hbeta_torch(D, beta=1.0):
+    P = torch.exp(-D.clone() * beta)
+
+    sumP = torch.sum(P)
+
+    H = torch.log(sumP) + beta * torch.sum(D * P) / sumP
+    P = P / sumP
+
+    return H, P
+
+
+def x2p_torch(X, tol=1e-5, perplexity=30.0):
+    """
+        Performs a binary search to get P-values in such a way that each
+        conditional Gaussian has the same perplexity.
+    """
+
+    # Initialize some variables
+    print("Computing pairwise distances...")
+    (n, d) = X.shape
+
+    sum_X = torch.sum(X*X, 1)
+    D = torch.add(torch.add(-2 * torch.mm(X, X.t()), sum_X).t(), sum_X)
+
+    P = torch.zeros(n, n)
+    beta = torch.ones(n, 1)
+    logU = torch.log(torch.tensor([perplexity]))
+    n_list = [i for i in range(n)]
+
+    # Loop over all datapoints
+    for i in range(n):
+
+        # Print progress
+        if i % 500 == 0:
+            print("Computing P-values for point %d of %d..." % (i, n))
+
+        # Compute the Gaussian kernel and entropy for the current precision
+        # there may be something wrong with this setting None
+        betamin = None
+        betamax = None
+        Di = D[i, n_list[0:i]+n_list[i+1:n]]
+
+        (H, thisP) = Hbeta_torch(Di, beta[i])
+
+        # Evaluate whether the perplexity is within tolerance
+        Hdiff = H - logU
+        tries = 0
+        while torch.abs(Hdiff) > tol and tries < 50:
+
+            # If not, increase or decrease precision
+            if Hdiff > 0:
+                betamin = beta[i].clone()
+                if betamax is None:
+                    beta[i] = beta[i] * 2.
+                else:
+                    beta[i] = (beta[i] + betamax) / 2.
+            else:
+                betamax = beta[i].clone()
+                if betamin is None:
+                    beta[i] = beta[i] / 2.
+                else:
+                    beta[i] = (beta[i] + betamin) / 2.
+
+            # Recompute the values
+            (H, thisP) = Hbeta_torch(Di, beta[i])
+
+            Hdiff = H - logU
+            tries += 1
+
+        # Set the final row of P
+        P[i, n_list[0:i]+n_list[i+1:n]] = thisP
+
+    # Return final P-matrix
+    return P
+
+
+def pca_torch(X, no_dims=50):
+    print("Preprocessing the data using PCA...")
+    (n, d) = X.shape
+    X = X - torch.mean(X, 0)
+
+    (l, M) = torch.eig(torch.mm(X.t(), X), True)
+    # split M real
+    for i in range(d):
+        if l[i, 1] != 0:
+            M[:, i+1] = M[:, i]
+            i += 1
+
+    Y = torch.mm(X, M[:, 0:no_dims])
+    return Y
+
+
+def tsne(X, no_dims=2, initial_dims=50, perplexity=30.0):
+    """
+        Runs t-SNE on the dataset in the NxD array X to reduce its
+        dimensionality to no_dims dimensions. The syntaxis of the function is
+        `Y = tsne.tsne(X, no_dims, perplexity), where X is an NxD NumPy array.
+    """
+
+    # Check inputs
+    if isinstance(no_dims, float):
+        print("Error: array X should not have type float.")
+        return -1
+    if round(no_dims) != no_dims:
+        print("Error: number of dimensions should be an integer.")
+        return -1
+
+    # Initialize variables
+    X = pca_torch(X, initial_dims).to('cuda')  # Sending to('cuda') after because torch.eig is broken in Windows currently on Ampere GPUs.
+    (n, d) = X.shape
+    max_iter = 1000
+    initial_momentum = 0.5
+    final_momentum = 0.8
+    eta = 500
+    min_gain = 0.01
+    Y = torch.randn(n, no_dims)
+    dY = torch.zeros(n, no_dims)
+    iY = torch.zeros(n, no_dims)
+    gains = torch.ones(n, no_dims)
+
+    # Compute P-values
+    P = x2p_torch(X, 1e-5, perplexity)
+    P = P + P.t()
+    P = P / torch.sum(P)
+    P = P * 4.    # early exaggeration
+    print("get P shape", P.shape)
+    P = torch.max(P, torch.tensor([1e-21]))
+
+    # Run iterations
+    for iter in tqdm(range(max_iter)):
+
+        # Compute pairwise affinities
+        sum_Y = torch.sum(Y*Y, 1)
+        num = -2. * torch.mm(Y, Y.t())
+        num = 1. / (1. + torch.add(torch.add(num, sum_Y).t(), sum_Y))
+        num[range(n), range(n)] = 0.
+        Q = num / torch.sum(num)
+        Q = torch.max(Q, torch.tensor([1e-12]))
+
+        # Compute gradient
+        PQ = P - Q
+        for i in range(n):
+            dY[i, :] = torch.sum((PQ[:, i] * num[:, i]).repeat(no_dims, 1).t() * (Y[i, :] - Y), 0)
+
+        # Perform the update
+        if iter < 20:
+            momentum = initial_momentum
+        else:
+            momentum = final_momentum
+
+        gains = (gains + 0.2) * ((dY > 0.) != (iY > 0.)).double() + (gains * 0.8) * ((dY > 0.) == (iY > 0.)).double()
+        gains[gains < min_gain] = min_gain
+        iY = momentum * iY - eta * (gains * dY)
+        Y = Y + iY
+        Y = Y - torch.mean(Y, 0)
+
+        # Compute current value of cost function
+        if (iter + 1) % 10 == 0:
+            C = torch.sum(P * torch.log(P / Q))
+            print("Iteration %d: error is %f" % (iter + 1, C))
+
+        # Stop lying about P-values
+        if iter == 100:
+            P = P / 4.
+
+    # Return solution
+    return Y
+
+
+def run_tsne():
+    print("Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset.")
+
+    limit = 4000
+    X, files = torch.load('results.pth')
+    zipped = list(zip(X, files))
+    shuffle(zipped)
+    X, files = zip(*zipped)
+    X = torch.cat(X, dim=0).squeeze()[:limit]
+    labels = np.zeros(X.shape[0])  # We don't have any labels..
+
+    # confirm that x file get same number point than label file
+    # otherwise may cause error in scatter
+    assert(len(X[:, 0])==len(X[:,1]))
+    assert(len(X)==len(labels))
+
+    with torch.no_grad():
+        Y = tsne(X, 2, 2048, 20.0)
+
+    if opt.cuda:
+        Y = Y.cpu().numpy()
+
+    # You may write result in two files
+    # print("Save Y values in file")
+    # Y1 = open("y1.txt", 'w')
+    # Y2 = open('y2.txt', 'w')
+    # for i in range(Y.shape[0]):
+    #     Y1.write(str(Y[i,0])+"\n")
+    #     Y2.write(str(Y[i,1])+"\n")
+
+    pyplot.scatter(Y[:, 0], Y[:, 1], 20, labels)
+    pyplot.show()
+    torch.save((Y, files[:limit]), "tsne_output.pth")
+
+
+# Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne
+# spectrum.
+def plot_results_as_image_graph():
+    Y, files = torch.load('tsne_output.pth')
+    fig, ax = pyplot.subplots()
+    fig.set_size_inches(200,200,forward=True)
+    ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
+    ax.autoscale()
+
+    for b in tqdm(range(Y.shape[0])):
+        im = pyplot.imread(files[b])
+        im = OffsetImage(im, zoom=1/8)
+        ab = AnnotationBbox(im, (Y[b, 0], Y[b, 1]), xycoords='data', frameon=False)
+        ax.add_artist(ab)
+    ax.scatter(Y[:, 0], Y[:, 1])
+
+    pyplot.savefig('tsne.pdf')
+
+
+if __name__ == "__main__":
+    #run_tsne()
+    plot_results_as_image_graph()
\ No newline at end of file