From 659814c20ff507773910574247424893644cca7e Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 7 Jan 2021 16:31:28 -0700 Subject: [PATCH] BYOL script updates --- .../scripts/byol/byol_spinenet_playground.py | 5 ++-- codes/scripts/byol/byol_uresnet_playground.py | 28 +++++++++++++++--- codes/scripts/byol/tsne_torch.py | 29 +++++++++++++++---- 3 files changed, 49 insertions(+), 13 deletions(-) diff --git a/codes/scripts/byol/byol_spinenet_playground.py b/codes/scripts/byol/byol_spinenet_playground.py index f51a1c07..31d0c621 100644 --- a/codes/scripts/byol/byol_spinenet_playground.py +++ b/codes/scripts/byol/byol_spinenet_playground.py @@ -53,10 +53,9 @@ def im_norm(x): def get_image_folder_dataloader(batch_size, num_workers): dataset_opt = dict_to_nonedict({ 'name': 'amalgam', - 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], - #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], + 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'], 'weights': [1], - 'target_size': 512, + 'target_size': 256, 'force_multiple': 32, 'scale': 1 }) diff --git a/codes/scripts/byol/byol_uresnet_playground.py b/codes/scripts/byol/byol_uresnet_playground.py index 50ff1b28..7296dc32 100644 --- a/codes/scripts/byol/byol_uresnet_playground.py +++ b/codes/scripts/byol/byol_uresnet_playground.py @@ -57,7 +57,7 @@ def im_norm(x): def get_image_folder_dataloader(batch_size, num_workers): dataset_opt = dict_to_nonedict({ 'name': 'amalgam', - 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], + 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], 'weights': [1], 'target_size': 256, @@ -68,9 +68,28 @@ def get_image_folder_dataloader(batch_size, num_workers): return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) +def produce_latent_dict(model): + batch_size = 32 + 
num_workers = 4 + dataloader = get_image_folder_dataloader(batch_size, num_workers) + id = 0 + paths = [] + latents = [] + for batch in tqdm(dataloader): + hq = batch['hq'].to('cuda') + l = model(hq).cpu().split(1, dim=0) + latents.extend(l) + paths.extend(batch['HQ_path']) + id += batch_size + if id > 1000: + print("Saving checkpoint..") + torch.save((latents, paths), '../results.pth') + id = 0 + + if __name__ == '__main__': - pretrained_path = '../experiments/uresnet_pixpro_83k.pth' - model = UResNet50(Bottleneck, [3,4,6,3]).to('cuda') + pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth' + model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda') sd = torch.load(pretrained_path) resnet_sd = {} for k, v in sd.items(): @@ -80,5 +99,6 @@ if __name__ == '__main__': model.eval() with torch.no_grad(): - find_similar_latents(model, 0, 8, structural_euc_dist) + #find_similar_latents(model, 0, 8, structural_euc_dist) #create_latent_database(model, batch_size=32) + produce_latent_dict(model) diff --git a/codes/scripts/byol/tsne_torch.py b/codes/scripts/byol/tsne_torch.py index e67ea3c7..503f4d65 100644 --- a/codes/scripts/byol/tsne_torch.py +++ b/codes/scripts/byol/tsne_torch.py @@ -258,9 +258,11 @@ def plot_instance_level_results_as_image_graph(): pyplot.savefig('tsne.pdf') -random_coords = [(16,16), (14,14), (20,20), (24,24)] +random_coords = [(8,8),(12,12),(18,18),(24,24)] def run_tsne_pixel_level(): limit = 4000 + + ''' # For spinenet-style latent dicts latent_dict = torch.load('../results/byol_latents/latent_dict_1.pth') id_vals = list(latent_dict.items()) ids, X = zip(*id_vals) @@ -272,6 +274,22 @@ def run_tsne_pixel_level(): for rc in random_coords: X_c.append(X[:, :, rc[0], rc[1]]) X = torch.cat(X_c, dim=0) + ''' + + # For resnet-style latent tuples + X, files = torch.load('../results.pth') + zipped = list(zip(X, files)) + shuffle(zipped) + X, files = zip(*zipped) + + X = torch.stack(X, dim=0)[:limit//4] + # Unravel X into 1 latents per image, 
chosen from fixed points. This will serve as a pseudorandom source since these + # images are not aligned. + X_c = [] + for rc in random_coords: + X_c.append(X[:, 0, :, rc[0], rc[1]]) + X = torch.cat(X_c, dim=0) + labels = np.zeros(X.shape[0]) # We don't have any labels.. # confirm that x file get same number point than label file @@ -295,21 +313,20 @@ def run_tsne_pixel_level(): pyplot.scatter(Y[:, 0], Y[:, 1], 20, labels) pyplot.show() - torch.save((Y, ids[:limit//4]), "../tsne_output_pix.pth") + torch.save((Y, files[:limit//4]), "../tsne_output_pix.pth") # Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne # spectrum. def plot_pixel_level_results_as_image_graph(): - Y, ids = torch.load('../tsne_output_pix.pth') - files = torch.load('../results/byol_latents/all_paths.pth') + Y, files = torch.load('../tsne_output_pix.pth') fig, ax = pyplot.subplots() fig.set_size_inches(200,200,forward=True) ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]])) ax.autoscale() - expansion = 32 # Should be latent_compression(=8) * image_compression_at_inference(=4) - margins = 1 # Keep in mind this will be multiplied by + expansion = 8 # Should be latent_compression(=8) * image_compression_at_inference(=1) + margins = 4 # Keep in mind this will be multiplied by for b in tqdm(range(Y.shape[0])): if b % 4 == 0: id = b // 4