forked from mrq/DL-Art-School
BYOL script updates
This commit is contained in:
parent
de10c7246a
commit
659814c20f
|
@ -53,10 +53,9 @@ def im_norm(x):
|
|||
def get_image_folder_dataloader(batch_size, num_workers):
|
||||
dataset_opt = dict_to_nonedict({
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||
'weights': [1],
|
||||
'target_size': 512,
|
||||
'target_size': 256,
|
||||
'force_multiple': 32,
|
||||
'scale': 1
|
||||
})
|
||||
|
|
|
@ -57,7 +57,7 @@ def im_norm(x):
|
|||
def get_image_folder_dataloader(batch_size, num_workers):
|
||||
dataset_opt = dict_to_nonedict({
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||
'weights': [1],
|
||||
'target_size': 256,
|
||||
|
@ -68,9 +68,28 @@ def get_image_folder_dataloader(batch_size, num_workers):
|
|||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
||||
|
||||
|
||||
def produce_latent_dict(model):
|
||||
batch_size = 32
|
||||
num_workers = 4
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||
id = 0
|
||||
paths = []
|
||||
latents = []
|
||||
for batch in tqdm(dataloader):
|
||||
hq = batch['hq'].to('cuda')
|
||||
l = model(hq).cpu().split(1, dim=0)
|
||||
latents.extend(l)
|
||||
paths.extend(batch['HQ_path'])
|
||||
id += batch_size
|
||||
if id > 1000:
|
||||
print("Saving checkpoint..")
|
||||
torch.save((latents, paths), '../results.pth')
|
||||
id = 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pretrained_path = '../experiments/uresnet_pixpro_83k.pth'
|
||||
model = UResNet50(Bottleneck, [3,4,6,3]).to('cuda')
|
||||
pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth'
|
||||
model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
|
||||
sd = torch.load(pretrained_path)
|
||||
resnet_sd = {}
|
||||
for k, v in sd.items():
|
||||
|
@ -80,5 +99,6 @@ if __name__ == '__main__':
|
|||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
find_similar_latents(model, 0, 8, structural_euc_dist)
|
||||
#find_similar_latents(model, 0, 8, structural_euc_dist)
|
||||
#create_latent_database(model, batch_size=32)
|
||||
produce_latent_dict(model)
|
||||
|
|
|
@ -258,9 +258,11 @@ def plot_instance_level_results_as_image_graph():
|
|||
pyplot.savefig('tsne.pdf')
|
||||
|
||||
|
||||
random_coords = [(16,16), (14,14), (20,20), (24,24)]
|
||||
random_coords = [(8,8),(12,12),(18,18),(24,24)]
|
||||
def run_tsne_pixel_level():
|
||||
limit = 4000
|
||||
|
||||
''' # For spinenet-style latent dicts
|
||||
latent_dict = torch.load('../results/byol_latents/latent_dict_1.pth')
|
||||
id_vals = list(latent_dict.items())
|
||||
ids, X = zip(*id_vals)
|
||||
|
@ -272,6 +274,22 @@ def run_tsne_pixel_level():
|
|||
for rc in random_coords:
|
||||
X_c.append(X[:, :, rc[0], rc[1]])
|
||||
X = torch.cat(X_c, dim=0)
|
||||
'''
|
||||
|
||||
# For resnet-style latent tuples
|
||||
X, files = torch.load('../results.pth')
|
||||
zipped = list(zip(X, files))
|
||||
shuffle(zipped)
|
||||
X, files = zip(*zipped)
|
||||
|
||||
X = torch.stack(X, dim=0)[:limit//4]
|
||||
# Unravel X into 1 latents per image, chosen from fixed points. This will serve as a psuedorandom source since these
|
||||
# images are not aligned.
|
||||
X_c = []
|
||||
for rc in random_coords:
|
||||
X_c.append(X[:, 0, :, rc[0], rc[1]])
|
||||
X = torch.cat(X_c, dim=0)
|
||||
|
||||
labels = np.zeros(X.shape[0]) # We don't have any labels..
|
||||
|
||||
# confirm that x file get same number point than label file
|
||||
|
@ -295,21 +313,20 @@ def run_tsne_pixel_level():
|
|||
|
||||
pyplot.scatter(Y[:, 0], Y[:, 1], 20, labels)
|
||||
pyplot.show()
|
||||
torch.save((Y, ids[:limit//4]), "../tsne_output_pix.pth")
|
||||
torch.save((Y, files[:limit//4]), "../tsne_output_pix.pth")
|
||||
|
||||
|
||||
# Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne
|
||||
# spectrum.
|
||||
def plot_pixel_level_results_as_image_graph():
|
||||
Y, ids = torch.load('../tsne_output_pix.pth')
|
||||
files = torch.load('../results/byol_latents/all_paths.pth')
|
||||
Y, files = torch.load('../tsne_output_pix.pth')
|
||||
fig, ax = pyplot.subplots()
|
||||
fig.set_size_inches(200,200,forward=True)
|
||||
ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
|
||||
ax.autoscale()
|
||||
|
||||
expansion = 32 # Should be latent_compression(=8) * image_compression_at_inference(=4)
|
||||
margins = 1 # Keep in mind this will be multiplied by <expansion>
|
||||
expansion = 8 # Should be latent_compression(=8) * image_compression_at_inference(=1)
|
||||
margins = 4 # Keep in mind this will be multiplied by <expansion>
|
||||
for b in tqdm(range(Y.shape[0])):
|
||||
if b % 4 == 0:
|
||||
id = b // 4
|
||||
|
|
Loading…
Reference in New Issue
Block a user