forked from mrq/DL-Art-School
BYOL script updates
This commit is contained in:
parent
de10c7246a
commit
659814c20f
|
@ -53,10 +53,9 @@ def im_norm(x):
|
||||||
def get_image_folder_dataloader(batch_size, num_workers):
|
def get_image_folder_dataloader(batch_size, num_workers):
|
||||||
dataset_opt = dict_to_nonedict({
|
dataset_opt = dict_to_nonedict({
|
||||||
'name': 'amalgam',
|
'name': 'amalgam',
|
||||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
|
||||||
'weights': [1],
|
'weights': [1],
|
||||||
'target_size': 512,
|
'target_size': 256,
|
||||||
'force_multiple': 32,
|
'force_multiple': 32,
|
||||||
'scale': 1
|
'scale': 1
|
||||||
})
|
})
|
||||||
|
|
|
@ -57,7 +57,7 @@ def im_norm(x):
|
||||||
def get_image_folder_dataloader(batch_size, num_workers):
|
def get_image_folder_dataloader(batch_size, num_workers):
|
||||||
dataset_opt = dict_to_nonedict({
|
dataset_opt = dict_to_nonedict({
|
||||||
'name': 'amalgam',
|
'name': 'amalgam',
|
||||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||||
'weights': [1],
|
'weights': [1],
|
||||||
'target_size': 256,
|
'target_size': 256,
|
||||||
|
@ -68,9 +68,28 @@ def get_image_folder_dataloader(batch_size, num_workers):
|
||||||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
||||||
|
|
||||||
|
|
||||||
|
def produce_latent_dict(model):
|
||||||
|
batch_size = 32
|
||||||
|
num_workers = 4
|
||||||
|
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||||
|
id = 0
|
||||||
|
paths = []
|
||||||
|
latents = []
|
||||||
|
for batch in tqdm(dataloader):
|
||||||
|
hq = batch['hq'].to('cuda')
|
||||||
|
l = model(hq).cpu().split(1, dim=0)
|
||||||
|
latents.extend(l)
|
||||||
|
paths.extend(batch['HQ_path'])
|
||||||
|
id += batch_size
|
||||||
|
if id > 1000:
|
||||||
|
print("Saving checkpoint..")
|
||||||
|
torch.save((latents, paths), '../results.pth')
|
||||||
|
id = 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pretrained_path = '../experiments/uresnet_pixpro_83k.pth'
|
pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth'
|
||||||
model = UResNet50(Bottleneck, [3,4,6,3]).to('cuda')
|
model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
|
||||||
sd = torch.load(pretrained_path)
|
sd = torch.load(pretrained_path)
|
||||||
resnet_sd = {}
|
resnet_sd = {}
|
||||||
for k, v in sd.items():
|
for k, v in sd.items():
|
||||||
|
@ -80,5 +99,6 @@ if __name__ == '__main__':
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
find_similar_latents(model, 0, 8, structural_euc_dist)
|
#find_similar_latents(model, 0, 8, structural_euc_dist)
|
||||||
#create_latent_database(model, batch_size=32)
|
#create_latent_database(model, batch_size=32)
|
||||||
|
produce_latent_dict(model)
|
||||||
|
|
|
@ -258,9 +258,11 @@ def plot_instance_level_results_as_image_graph():
|
||||||
pyplot.savefig('tsne.pdf')
|
pyplot.savefig('tsne.pdf')
|
||||||
|
|
||||||
|
|
||||||
random_coords = [(16,16), (14,14), (20,20), (24,24)]
|
random_coords = [(8,8),(12,12),(18,18),(24,24)]
|
||||||
def run_tsne_pixel_level():
|
def run_tsne_pixel_level():
|
||||||
limit = 4000
|
limit = 4000
|
||||||
|
|
||||||
|
''' # For spinenet-style latent dicts
|
||||||
latent_dict = torch.load('../results/byol_latents/latent_dict_1.pth')
|
latent_dict = torch.load('../results/byol_latents/latent_dict_1.pth')
|
||||||
id_vals = list(latent_dict.items())
|
id_vals = list(latent_dict.items())
|
||||||
ids, X = zip(*id_vals)
|
ids, X = zip(*id_vals)
|
||||||
|
@ -272,6 +274,22 @@ def run_tsne_pixel_level():
|
||||||
for rc in random_coords:
|
for rc in random_coords:
|
||||||
X_c.append(X[:, :, rc[0], rc[1]])
|
X_c.append(X[:, :, rc[0], rc[1]])
|
||||||
X = torch.cat(X_c, dim=0)
|
X = torch.cat(X_c, dim=0)
|
||||||
|
'''
|
||||||
|
|
||||||
|
# For resnet-style latent tuples
|
||||||
|
X, files = torch.load('../results.pth')
|
||||||
|
zipped = list(zip(X, files))
|
||||||
|
shuffle(zipped)
|
||||||
|
X, files = zip(*zipped)
|
||||||
|
|
||||||
|
X = torch.stack(X, dim=0)[:limit//4]
|
||||||
|
# Unravel X into 1 latents per image, chosen from fixed points. This will serve as a psuedorandom source since these
|
||||||
|
# images are not aligned.
|
||||||
|
X_c = []
|
||||||
|
for rc in random_coords:
|
||||||
|
X_c.append(X[:, 0, :, rc[0], rc[1]])
|
||||||
|
X = torch.cat(X_c, dim=0)
|
||||||
|
|
||||||
labels = np.zeros(X.shape[0]) # We don't have any labels..
|
labels = np.zeros(X.shape[0]) # We don't have any labels..
|
||||||
|
|
||||||
# confirm that x file get same number point than label file
|
# confirm that x file get same number point than label file
|
||||||
|
@ -295,21 +313,20 @@ def run_tsne_pixel_level():
|
||||||
|
|
||||||
pyplot.scatter(Y[:, 0], Y[:, 1], 20, labels)
|
pyplot.scatter(Y[:, 0], Y[:, 1], 20, labels)
|
||||||
pyplot.show()
|
pyplot.show()
|
||||||
torch.save((Y, ids[:limit//4]), "../tsne_output_pix.pth")
|
torch.save((Y, files[:limit//4]), "../tsne_output_pix.pth")
|
||||||
|
|
||||||
|
|
||||||
# Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne
|
# Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne
|
||||||
# spectrum.
|
# spectrum.
|
||||||
def plot_pixel_level_results_as_image_graph():
|
def plot_pixel_level_results_as_image_graph():
|
||||||
Y, ids = torch.load('../tsne_output_pix.pth')
|
Y, files = torch.load('../tsne_output_pix.pth')
|
||||||
files = torch.load('../results/byol_latents/all_paths.pth')
|
|
||||||
fig, ax = pyplot.subplots()
|
fig, ax = pyplot.subplots()
|
||||||
fig.set_size_inches(200,200,forward=True)
|
fig.set_size_inches(200,200,forward=True)
|
||||||
ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
|
ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
|
||||||
ax.autoscale()
|
ax.autoscale()
|
||||||
|
|
||||||
expansion = 32 # Should be latent_compression(=8) * image_compression_at_inference(=4)
|
expansion = 8 # Should be latent_compression(=8) * image_compression_at_inference(=1)
|
||||||
margins = 1 # Keep in mind this will be multiplied by <expansion>
|
margins = 4 # Keep in mind this will be multiplied by <expansion>
|
||||||
for b in tqdm(range(Y.shape[0])):
|
for b in tqdm(range(Y.shape[0])):
|
||||||
if b % 4 == 0:
|
if b % 4 == 0:
|
||||||
id = b // 4
|
id = b // 4
|
||||||
|
|
Loading…
Reference in New Issue
Block a user