forked from mrq/DL-Art-School
spinenet latent playground!
This commit is contained in:
parent
20a09cb31b
commit
88fc049c8d
|
@ -9,7 +9,7 @@ from io import BytesIO
|
||||||
# options.
|
# options.
|
||||||
class ImageCorruptor:
|
class ImageCorruptor:
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
self.fixed_corruptions = opt['fixed_corruptions']
|
self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
|
||||||
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
|
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
|
||||||
if self.num_corrupts == 0:
|
if self.num_corrupts == 0:
|
||||||
return
|
return
|
||||||
|
|
|
@ -95,7 +95,7 @@ class ImageFolderDataset:
|
||||||
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
|
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
|
||||||
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
|
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
|
||||||
|
|
||||||
return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
|
return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -209,7 +209,7 @@ class SpineNet(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
arch,
|
arch,
|
||||||
in_channels=3,
|
in_channels=3,
|
||||||
output_level=[3, 4, 5, 6, 7],
|
output_level=[3, 4],
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN', requires_grad=True),
|
norm_cfg=dict(type='BN', requires_grad=True),
|
||||||
zero_init_residual=True,
|
zero_init_residual=True,
|
||||||
|
|
94
codes/scripts/byol_spinenet_playground.py
Normal file
94
codes/scripts/byol_spinenet_playground.py
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from data.image_folder_dataset import ImageFolderDataset
|
||||||
|
from models.archs.spinenet_arch import SpineNet
|
||||||
|
|
||||||
|
|
||||||
|
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||||
|
# and the distance is computed across the channel dimension.
|
||||||
|
def structural_euc_dist(x, y):
|
||||||
|
diff = torch.square(x - y)
|
||||||
|
sum = torch.sum(diff, dim=1)
|
||||||
|
return torch.sqrt(sum)
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(x, y):
|
||||||
|
return nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
|
||||||
|
|
||||||
|
|
||||||
|
def im_norm(x):
|
||||||
|
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_folder_dataloader(batch_size, num_workers):
|
||||||
|
dataset_opt = {
|
||||||
|
'name': 'amalgam',
|
||||||
|
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||||
|
'weights': [1],
|
||||||
|
'target_size': 512,
|
||||||
|
'force_multiple': 32,
|
||||||
|
'scale': 1
|
||||||
|
}
|
||||||
|
dataset = ImageFolderDataset(dataset_opt)
|
||||||
|
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
|
||||||
|
|
||||||
|
|
||||||
|
def create_latent_database(model):
|
||||||
|
batch_size = 8
|
||||||
|
num_workers = 1
|
||||||
|
output_path = '../../results/byol_spinenet_latents/'
|
||||||
|
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||||
|
id = 0
|
||||||
|
latent_dict = {}
|
||||||
|
for batch in tqdm(dataloader):
|
||||||
|
hq = batch['hq'].to('cuda:1')
|
||||||
|
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
|
||||||
|
for b in range(latent.shape[0]):
|
||||||
|
shutil.copy(batch[b]['HQ_path'], os.path.join(output_path, "%i.jpg" % (id,)))
|
||||||
|
latent_dict[id] = latent[b].detach().cpu()
|
||||||
|
if id % 100 == 0:
|
||||||
|
print("Saving checkpoint..")
|
||||||
|
torch.save(latent_dict, "latent_dict.pth")
|
||||||
|
id += 1
|
||||||
|
|
||||||
|
|
||||||
|
def explore_latent_results(model):
|
||||||
|
batch_size = 8
|
||||||
|
num_workers = 1
|
||||||
|
output_path = '../../results/byol_spinenet_explore_latents/'
|
||||||
|
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||||
|
id = 0
|
||||||
|
for batch in tqdm(dataloader):
|
||||||
|
hq = batch['hq'].to('cuda:1')
|
||||||
|
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
|
||||||
|
# This operation works by computing the distance of every structural index from the center and using that
|
||||||
|
# as a "heatmap".
|
||||||
|
b, c, h, w = latent.shape
|
||||||
|
center = latent[:, :, h//2, w//2].unsqueeze(-1).unsqueeze(-1)
|
||||||
|
centers = center.repeat(1, 1, h, w)
|
||||||
|
dist = structural_euc_dist(latent, centers).unsqueeze(1)
|
||||||
|
dist = im_norm(dist)
|
||||||
|
torchvision.utils.save_image(dist, os.path.join(output_path, "%i.png" % id))
|
||||||
|
id += 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pretrained_path = '../../experiments/spinenet49_imgset_byol.pth'
|
||||||
|
|
||||||
|
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1')
|
||||||
|
model.load_state_dict(torch.load(pretrained_path), strict=True)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
explore_latent_results(model)
|
Loading…
Reference in New Issue
Block a user