From 88fc049c8d9f65be1954a04a7317ed7000d2d668 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 5 Dec 2020 20:30:36 -0700 Subject: [PATCH] spinenet latent playground! --- codes/data/image_corruptor.py | 2 +- codes/data/image_folder_dataset.py | 2 +- codes/models/archs/spinenet_arch.py | 2 +- codes/scripts/byol_spinenet_playground.py | 94 +++++++++++++++++++++++ 4 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 codes/scripts/byol_spinenet_playground.py diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py index a613175f..323ddd09 100644 --- a/codes/data/image_corruptor.py +++ b/codes/data/image_corruptor.py @@ -9,7 +9,7 @@ from io import BytesIO # options. class ImageCorruptor: def __init__(self, opt): - self.fixed_corruptions = opt['fixed_corruptions'] + self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else [] self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0 if self.num_corrupts == 0: return diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index 31309edc..04aff28c 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -95,7 +95,7 @@ class ImageFolderDataset: hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float() lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float() - return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]} + return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]} if __name__ == '__main__': diff --git a/codes/models/archs/spinenet_arch.py b/codes/models/archs/spinenet_arch.py index 85fb71dd..3e4924f5 100644 --- a/codes/models/archs/spinenet_arch.py +++ b/codes/models/archs/spinenet_arch.py @@ -209,7 +209,7 @@ class SpineNet(nn.Module): def __init__(self, arch, in_channels=3, - output_level=[3, 4, 5, 6, 7], + output_level=[3, 4], conv_cfg=None, norm_cfg=dict(type='BN', requires_grad=True), zero_init_residual=True, diff --git a/codes/scripts/byol_spinenet_playground.py b/codes/scripts/byol_spinenet_playground.py new file mode 100644 index 00000000..498fa1c2 --- /dev/null +++ b/codes/scripts/byol_spinenet_playground.py @@ -0,0 +1,94 @@ +import os +import shutil + +import torch +import torch.nn as nn +import torchvision +from torch.utils.data import DataLoader +from tqdm import tqdm + +from data.image_folder_dataset import ImageFolderDataset +from models.archs.spinenet_arch import SpineNet + + +# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved +# and the distance is computed across the channel dimension. +def structural_euc_dist(x, y): + diff = torch.square(x - y) + sum = torch.sum(diff, dim=1) + return torch.sqrt(sum) + + +def cosine_similarity(x, y): + return nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself. + + +def im_norm(x): + return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5 + + +def get_image_folder_dataloader(batch_size, num_workers): + dataset_opt = { + 'name': 'amalgam', + 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], + 'weights': [1], + 'target_size': 512, + 'force_multiple': 32, + 'scale': 1 + } + dataset = ImageFolderDataset(dataset_opt) + return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + + +def create_latent_database(model): + batch_size = 8 + num_workers = 1 + output_path = '../../results/byol_spinenet_latents/' + + os.makedirs(output_path, exist_ok=True) + dataloader = get_image_folder_dataloader(batch_size, num_workers) + id = 0 + latent_dict = {} + for batch in tqdm(dataloader): + hq = batch['hq'].to('cuda:1') + latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. + for b in range(latent.shape[0]): + shutil.copy(batch[b]['HQ_path'], os.path.join(output_path, "%i.jpg" % (id,))) + latent_dict[id] = latent[b].detach().cpu() + if id % 100 == 0: + print("Saving checkpoint..") + torch.save(latent_dict, "latent_dict.pth") + id += 1 + + +def explore_latent_results(model): + batch_size = 8 + num_workers = 1 + output_path = '../../results/byol_spinenet_explore_latents/' + + os.makedirs(output_path, exist_ok=True) + dataloader = get_image_folder_dataloader(batch_size, num_workers) + id = 0 + for batch in tqdm(dataloader): + hq = batch['hq'].to('cuda:1') + latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. + # This operation works by computing the distance of every structural index from the center and using that + # as a "heatmap". + b, c, h, w = latent.shape + center = latent[:, :, h//2, w//2].unsqueeze(-1).unsqueeze(-1) + centers = center.repeat(1, 1, h, w) + dist = structural_euc_dist(latent, centers).unsqueeze(1) + dist = im_norm(dist) + torchvision.utils.save_image(dist, os.path.join(output_path, "%i.png" % id)) + id += 1 + + +if __name__ == '__main__': + pretrained_path = '../../experiments/spinenet49_imgset_byol.pth' + + model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1') + model.load_state_dict(torch.load(pretrained_path), strict=True) + model.eval() + + with torch.no_grad(): + explore_latent_results(model) \ No newline at end of file