DL-Art-School/codes/scripts/byol_spinenet_playground.py

94 lines
3.3 KiB
Python
Raw Normal View History

2020-12-06 03:30:36 +00:00
import os
import shutil
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
from data.image_folder_dataset import ImageFolderDataset
from models.archs.spinenet_arch import SpineNet
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
def structural_euc_dist(x, y):
diff = torch.square(x - y)
sum = torch.sum(diff, dim=1)
return torch.sqrt(sum)
def cosine_similarity(x, y):
return nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
def im_norm(x):
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
def get_image_folder_dataloader(batch_size, num_workers):
dataset_opt = {
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
'weights': [1],
'target_size': 512,
'force_multiple': 32,
'scale': 1
}
dataset = ImageFolderDataset(dataset_opt)
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
def create_latent_database(model):
batch_size = 8
num_workers = 1
output_path = '../../results/byol_spinenet_latents/'
os.makedirs(output_path, exist_ok=True)
dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0
latent_dict = {}
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda:1')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
for b in range(latent.shape[0]):
shutil.copy(batch[b]['HQ_path'], os.path.join(output_path, "%i.jpg" % (id,)))
latent_dict[id] = latent[b].detach().cpu()
if id % 100 == 0:
print("Saving checkpoint..")
torch.save(latent_dict, "latent_dict.pth")
id += 1
def explore_latent_results(model):
batch_size = 8
num_workers = 1
output_path = '../../results/byol_spinenet_explore_latents/'
os.makedirs(output_path, exist_ok=True)
dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda:1')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
# This operation works by computing the distance of every structural index from the center and using that
# as a "heatmap".
b, c, h, w = latent.shape
center = latent[:, :, h//2, w//2].unsqueeze(-1).unsqueeze(-1)
centers = center.repeat(1, 1, h, w)
dist = structural_euc_dist(latent, centers).unsqueeze(1)
dist = im_norm(dist)
torchvision.utils.save_image(dist, os.path.join(output_path, "%i.png" % id))
id += 1
if __name__ == '__main__':
pretrained_path = '../../experiments/spinenet49_imgset_byol.pth'
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1')
model.load_state_dict(torch.load(pretrained_path), strict=True)
model.eval()
with torch.no_grad():
explore_latent_results(model)