forked from mrq/DL-Art-School
byol_resnet_playground
Similar to the spinenet playground, but tinkers with resnet instead
This commit is contained in:
parent
8f18b2709e
commit
8de5a02a48
178
codes/scripts/byol_resnet_playground.py
Normal file
178
codes/scripts/byol_resnet_playground.py
Normal file
|
@ -0,0 +1,178 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import ToTensor, Resize
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
import utils
|
||||
from data.image_folder_dataset import ImageFolderDataset
|
||||
from models.resnet_with_checkpointing import resnet50
|
||||
from models.spinenet_arch import SpineNet
|
||||
|
||||
|
||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||
# and the distance is computed across the channel dimension.
|
||||
from utils import util
|
||||
from utils.options import dict_to_nonedict
|
||||
|
||||
|
||||
def structural_euc_dist(x, y):
|
||||
diff = torch.square(x - y)
|
||||
sum = torch.sum(diff, dim=-1)
|
||||
return torch.mean(torch.sqrt(sum))
|
||||
|
||||
|
||||
def cosine_similarity(x, y):
|
||||
x = norm(x)
|
||||
y = norm(y)
|
||||
return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
|
||||
|
||||
|
||||
def key_value_difference(x, y):
|
||||
x = F.normalize(x, dim=-1, p=2)
|
||||
y = F.normalize(y, dim=-1, p=2)
|
||||
return 2 - 2 * (x * y).sum(dim=-1)
|
||||
|
||||
|
||||
def norm(x):
|
||||
sh = x.shape
|
||||
sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
|
||||
return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r)
|
||||
|
||||
|
||||
def im_norm(x):
|
||||
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
|
||||
|
||||
|
||||
def get_image_folder_dataloader(batch_size, num_workers):
|
||||
dataset_opt = dict_to_nonedict({
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||
'weights': [1],
|
||||
'target_size': 224,
|
||||
'force_multiple': 32,
|
||||
'scale': 1
|
||||
})
|
||||
dataset = ImageFolderDataset(dataset_opt)
|
||||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
||||
|
||||
|
||||
def _find_layer(net, layer_name):
|
||||
if type(layer_name) == str:
|
||||
modules = dict([*net.named_modules()])
|
||||
return modules.get(layer_name, None)
|
||||
elif type(layer_name) == int:
|
||||
children = [*net.children()]
|
||||
return children[layer_name]
|
||||
return None
|
||||
|
||||
|
||||
layer_hooked_value = None
|
||||
def _hook(_, __, output):
|
||||
global layer_hooked_value
|
||||
layer_hooked_value = output
|
||||
|
||||
|
||||
def register_hook(net, layer_name):
|
||||
layer = _find_layer(net, layer_name)
|
||||
assert layer is not None, f'hidden layer ({self.layer}) not found'
|
||||
layer.register_forward_hook(_hook)
|
||||
|
||||
|
||||
def create_latent_database(model, model_index=0):
|
||||
batch_size = 32
|
||||
num_workers = 1
|
||||
output_path = '../../results/byol_resnet_latents/'
|
||||
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||
id = 0
|
||||
dict_count = 1
|
||||
latent_dict = {}
|
||||
all_paths = []
|
||||
for batch in tqdm(dataloader):
|
||||
hq = batch['hq'].to('cuda')
|
||||
latent = model(hq)[model_index] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
|
||||
for b in range(latent.shape[0]):
|
||||
im_path = batch['HQ_path'][b]
|
||||
all_paths.append(im_path)
|
||||
latent_dict[id] = latent[b].detach().cpu()
|
||||
if (id+1) % 1000 == 0:
|
||||
print("Saving checkpoint..")
|
||||
torch.save(latent_dict, os.path.join(output_path, "latent_dict_%i.pth" % (dict_count,)))
|
||||
latent_dict = {}
|
||||
torch.save(all_paths, os.path.join(output_path, "all_paths.pth"))
|
||||
dict_count += 1
|
||||
id += 1
|
||||
|
||||
|
||||
|
||||
def get_latent_for_img(model, img):
|
||||
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
|
||||
_, _, h, w = img_t.shape
|
||||
# Center crop img_t and resize to 224.
|
||||
d = min(h, w)
|
||||
dh, dw = (h-d)//2, (w-d)//2
|
||||
if dh == 0:
|
||||
img_t = img_t[:, :, :, dw:-dw]
|
||||
else:
|
||||
img_t = img_t[:, :, dh:-dh, :]
|
||||
img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area")
|
||||
model(img_t)
|
||||
latent = layer_hooked_value
|
||||
return latent
|
||||
|
||||
|
||||
def find_similar_latents(model, compare_fn=structural_euc_dist):
|
||||
global layer_hooked_value
|
||||
|
||||
img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\yui_xx.jpg'
|
||||
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
|
||||
output_path = '../../results/byol_resnet_similars'
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
imglatent = get_latent_for_img(model, img)
|
||||
_, c, h, w = imglatent.shape
|
||||
|
||||
batch_size = 32
|
||||
num_workers = 1
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||
id = 0
|
||||
results = []
|
||||
for batch in tqdm(dataloader):
|
||||
hq = batch['hq'].to('cuda')
|
||||
model(hq)
|
||||
latent = layer_hooked_value
|
||||
for b in range(latent.shape[0]):
|
||||
im_path = batch['HQ_path'][b]
|
||||
results.append((im_path, compare_fn(imglatent, latent[b].unsqueeze(0)).item()))
|
||||
id += 1
|
||||
if id > 2000:
|
||||
break
|
||||
results.sort(key=lambda x: x[1])
|
||||
for i in range(50):
|
||||
mag = results[i][1]
|
||||
shutil.copy(results[i][0], os.path.join(output_path, f'{i}_{mag}.jpg'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pretrained_path = '../../experiments/resnet_byol_diffframe_69k.pth'
|
||||
model = resnet50(pretrained=False).to('cuda')
|
||||
sd = torch.load(pretrained_path)
|
||||
resnet_sd = {}
|
||||
for k, v in sd.items():
|
||||
if 'target_encoder.net.' in k:
|
||||
resnet_sd[k.replace('target_encoder.net.', '')] = v
|
||||
model.load_state_dict(resnet_sd, strict=True)
|
||||
model.eval()
|
||||
register_hook(model, 'avgpool')
|
||||
|
||||
with torch.no_grad():
|
||||
find_similar_latents(model, structural_euc_dist)
|
Loading…
Reference in New Issue
Block a user