BYOL script updates

James Betker 2021-01-07 16:31:28 -07:00
parent de10c7246a
commit 659814c20f
3 changed files with 49 additions and 13 deletions

View File

@@ -53,10 +53,9 @@ def im_norm(x):
 def get_image_folder_dataloader(batch_size, num_workers):
     dataset_opt = dict_to_nonedict({
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
-        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
+        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
         'weights': [1],
-        'target_size': 512,
+        'target_size': 256,
         'force_multiple': 32,
         'scale': 1
     })

View File

@@ -57,7 +57,7 @@ def im_norm(x):
 def get_image_folder_dataloader(batch_size, num_workers):
     dataset_opt = dict_to_nonedict({
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
+        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
         #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
         'weights': [1],
         'target_size': 256,
@@ -68,9 +68,28 @@ def get_image_folder_dataloader(batch_size, num_workers):
     return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)


+def produce_latent_dict(model):
+    batch_size = 32
+    num_workers = 4
+    dataloader = get_image_folder_dataloader(batch_size, num_workers)
+    id = 0
+    paths = []
+    latents = []
+    for batch in tqdm(dataloader):
+        hq = batch['hq'].to('cuda')
+        l = model(hq).cpu().split(1, dim=0)
+        latents.extend(l)
+        paths.extend(batch['HQ_path'])
+        id += batch_size
+        if id > 1000:
+            print("Saving checkpoint..")
+            torch.save((latents, paths), '../results.pth')
+            id = 0
+
+
 if __name__ == '__main__':
-    pretrained_path = '../experiments/uresnet_pixpro_83k.pth'
-    model = UResNet50(Bottleneck, [3,4,6,3]).to('cuda')
+    pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth'
+    model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
     sd = torch.load(pretrained_path)
     resnet_sd = {}
     for k, v in sd.items():
@@ -80,5 +99,6 @@ if __name__ == '__main__':
     model.eval()

     with torch.no_grad():
-        find_similar_latents(model, 0, 8, structural_euc_dist)
+        #find_similar_latents(model, 0, 8, structural_euc_dist)
         #create_latent_database(model, batch_size=32)
+        produce_latent_dict(model)
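
Editor's note: a minimal sketch (not part of this commit) of consuming the (latents, paths) tuple that produce_latent_dict() checkpoints to ../results.pth. The per-item shape is an assumption inferred from the split(1, dim=0) call above.

    import torch

    # Hypothetical quick inspection of the checkpoint written by produce_latent_dict().
    latents, paths = torch.load('../results.pth')
    assert len(latents) == len(paths)   # one latent tensor per image path
    print(latents[0].shape)             # assumed [1, C, H, W]; C=512 when out_dim=512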

View File

@@ -258,9 +258,11 @@ def plot_instance_level_results_as_image_graph():
     pyplot.savefig('tsne.pdf')

-random_coords = [(16,16), (14,14), (20,20), (24,24)]
+random_coords = [(8,8),(12,12),(18,18),(24,24)]


 def run_tsne_pixel_level():
+    limit = 4000
+    ''' # For spinenet-style latent dicts
     latent_dict = torch.load('../results/byol_latents/latent_dict_1.pth')
     id_vals = list(latent_dict.items())
     ids, X = zip(*id_vals)
@@ -272,6 +274,22 @@ def run_tsne_pixel_level():
     for rc in random_coords:
         X_c.append(X[:, :, rc[0], rc[1]])
     X = torch.cat(X_c, dim=0)
+    '''
+
+    # For resnet-style latent tuples
+    X, files = torch.load('../results.pth')
+    zipped = list(zip(X, files))
+    shuffle(zipped)
+    X, files = zip(*zipped)
+    X = torch.stack(X, dim=0)[:limit//4]
+
+    # Unravel X into one latent per image per fixed point. This will serve as a pseudorandom source since these
+    # images are not aligned.
+    X_c = []
+    for rc in random_coords:
+        X_c.append(X[:, 0, :, rc[0], rc[1]])
+    X = torch.cat(X_c, dim=0)
+
     labels = np.zeros(X.shape[0])  # We don't have any labels..
     # confirm that X and labels contain the same number of points
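
Editor's note: a minimal, self-contained illustration (not from the commit) of the fixed-coordinate trick above. Because the images are not aligned, reading each latent map at a handful of fixed (h, w) positions behaves like a pseudorandom sample of image content; the [N, 1, C, H, W] shape is an assumption inferred from the X[:, 0, :, rc[0], rc[1]] indexing.

    import torch

    # Assumed shape [N, 1, C, H, W]: N images, each with a [C, H, W] latent map.
    X = torch.randn(10, 1, 512, 32, 32)
    coords = [(8, 8), (12, 12), (18, 18), (24, 24)]

    # One C-dim vector per image per coordinate, concatenated into [4*N, C].
    samples = torch.cat([X[:, 0, :, h, w] for h, w in coords], dim=0)
    print(samples.shape)  # torch.Size([40, 512])
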
@@ -295,21 +313,20 @@ def run_tsne_pixel_level():
     pyplot.scatter(Y[:, 0], Y[:, 1], 20, labels)
     pyplot.show()
-    torch.save((Y, ids[:limit//4]), "../tsne_output_pix.pth")
+    torch.save((Y, files[:limit//4]), "../tsne_output_pix.pth")


 # Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne
 # spectrum.
 def plot_pixel_level_results_as_image_graph():
-    Y, ids = torch.load('../tsne_output_pix.pth')
-    files = torch.load('../results/byol_latents/all_paths.pth')
+    Y, files = torch.load('../tsne_output_pix.pth')
     fig, ax = pyplot.subplots()
     fig.set_size_inches(200,200,forward=True)
     ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
     ax.autoscale()

-    expansion = 32  # Should be latent_compression(=8) * image_compression_at_inference(=4)
-    margins = 1  # Keep in mind this will be multiplied by <expansion>
+    expansion = 8  # Should be latent_compression(=8) * image_compression_at_inference(=1)
+    margins = 4  # Keep in mind this will be multiplied by <expansion>

     for b in tqdm(range(Y.shape[0])):
         if b % 4 == 0:
             id = b // 4
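
Editor's note: a minimal sketch (not from the commit) of reading back the (Y, files) pair that run_tsne_pixel_level() now saves to ../tsne_output_pix.pth and re-plotting the raw embedding. The 4-points-per-image layout mirrors the four fixed coordinates above.

    import torch
    from matplotlib import pyplot

    # Y holds one 2-D t-SNE point per sampled latent (4 per image);
    # files holds the corresponding image paths.
    Y, files = torch.load('../tsne_output_pix.pth')
    print(Y.shape[0], 'points for', len(files), 'images')
    pyplot.scatter(Y[:, 0], Y[:, 1], 20)
    pyplot.show()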