forked from mrq/DL-Art-School
Spinenet playground
This commit is contained in:
parent
88fc049c8d
commit
c0aeaabc31
|
@ -29,7 +29,7 @@ class ImageFolderDataset:
|
|||
self.weights = opt['weights']
|
||||
|
||||
# Just scan the given directory for images of standard types.
|
||||
supported_types = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG', 'gif', 'GIF']
|
||||
supported_types = ['jpg', 'jpeg', 'png', 'gif']
|
||||
self.image_paths = []
|
||||
for path, weight in zip(self.paths, self.weights):
|
||||
cache_path = os.path.join(path, 'cache.pth')
|
||||
|
|
|
@ -4,8 +4,11 @@ import shutil
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import ToTensor, Resize
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
from data.image_folder_dataset import ImageFolderDataset
|
||||
from models.archs.spinenet_arch import SpineNet
|
||||
|
@ -15,12 +18,20 @@ from models.archs.spinenet_arch import SpineNet
|
|||
# and the distance is computed across the channel dimension.
|
||||
def structural_euc_dist(x, y):
|
||||
diff = torch.square(x - y)
|
||||
sum = torch.sum(diff, dim=1)
|
||||
sum = torch.sum(diff, dim=-1)
|
||||
return torch.sqrt(sum)
|
||||
|
||||
|
||||
def cosine_similarity(x, y):
|
||||
return nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
|
||||
x = norm(x)
|
||||
y = norm(y)
|
||||
return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
|
||||
|
||||
|
||||
def norm(x):
|
||||
sh = x.shape
|
||||
sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
|
||||
return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r)
|
||||
|
||||
|
||||
def im_norm(x):
|
||||
|
@ -30,14 +41,15 @@ def im_norm(x):
|
|||
def get_image_folder_dataloader(batch_size, num_workers):
|
||||
dataset_opt = {
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||
'weights': [1],
|
||||
'target_size': 512,
|
||||
'force_multiple': 32,
|
||||
'scale': 1
|
||||
}
|
||||
dataset = ImageFolderDataset(dataset_opt)
|
||||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
|
||||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
||||
|
||||
|
||||
def create_latent_database(model):
|
||||
|
@ -48,21 +60,121 @@ def create_latent_database(model):
|
|||
os.makedirs(output_path, exist_ok=True)
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||
id = 0
|
||||
dict_count = 1
|
||||
latent_dict = {}
|
||||
all_paths = []
|
||||
for batch in tqdm(dataloader):
|
||||
hq = batch['hq'].to('cuda:1')
|
||||
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
|
||||
for b in range(latent.shape[0]):
|
||||
shutil.copy(batch[b]['HQ_path'], os.path.join(output_path, "%i.jpg" % (id,)))
|
||||
im_path = batch['HQ_path'][b]
|
||||
all_paths.append(im_path)
|
||||
latent_dict[id] = latent[b].detach().cpu()
|
||||
if id % 100 == 0:
|
||||
if (id+1) % 1000 == 0:
|
||||
print("Saving checkpoint..")
|
||||
torch.save(latent_dict, "latent_dict.pth")
|
||||
torch.save(latent_dict, os.path.join(output_path, "latent_dict_%i.pth" % (dict_count,)))
|
||||
latent_dict = {}
|
||||
torch.save(all_paths, os.path.join(output_path, "all_paths.pth"))
|
||||
dict_count += 1
|
||||
id += 1
|
||||
|
||||
|
||||
def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size):
|
||||
_, c, h, w = latent.shape
|
||||
lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name))
|
||||
comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1)
|
||||
cbl_shape = comparables.shape[:3]
|
||||
assert cbl_shape[1] == 32
|
||||
comparables = comparables.reshape(-1, c)
|
||||
|
||||
clat = latent.reshape(1,-1,h*w).permute(2,0,1)
|
||||
cpbl_chunked = torch.chunk(comparables, len(comparables) // batch_size)
|
||||
assert len(comparables) % batch_size == 0 # The reconstruction logic doesn't work if this is not the case.
|
||||
mins = []
|
||||
min_offsets = []
|
||||
for cpbl_chunk in tqdm(cpbl_chunked):
|
||||
cpbl_chunk = cpbl_chunk.to('cuda:1')
|
||||
dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0))
|
||||
_min = torch.min(dist, dim=-1)
|
||||
mins.append(_min[0])
|
||||
min_offsets.append(_min[1])
|
||||
mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
|
||||
# There's some way to do this in torch, I just can't figure it out..
|
||||
for i in range(len(mins[1])):
|
||||
mins[1][i] = mins[1][i] * batch_size + min_offsets[mins[1][i]][i]
|
||||
|
||||
return mins[0].cpu(), mins[1].cpu(), len(comparables)
|
||||
|
||||
|
||||
def find_similar_latents(model):
|
||||
img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg'
|
||||
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
|
||||
hq_img_repo = '../../results/byol_spinenet_latents'
|
||||
output_path = '../../results/byol_spinenet_similars'
|
||||
batch_size = 1024
|
||||
num_maps = 8
|
||||
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth"))
|
||||
img_t = ToTensor()(Image.open(img)).to('cuda:1').unsqueeze(0)
|
||||
_, _, h, w = img_t.shape
|
||||
img_t = img_t[:, :, :128*(h//128), :128*(w//128)]
|
||||
|
||||
latent = model(img_t)[1]
|
||||
_, c, h, w = latent.shape
|
||||
mins, min_offsets = [], []
|
||||
total_latents = -1
|
||||
for d_id in range(1,num_maps+1):
|
||||
mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size)
|
||||
if total_latents != -1:
|
||||
assert total_latents == tl
|
||||
else:
|
||||
total_latents = tl
|
||||
mins.append(mn)
|
||||
min_offsets.append(of)
|
||||
mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
|
||||
# There's some way to do this in torch, I just can't figure it out..
|
||||
for i in range(len(mins[1])):
|
||||
mins[1][i] = mins[1][i] * total_latents + min_offsets[mins[1][i]][i]
|
||||
min_ids = mins[1]
|
||||
|
||||
print("Constructing image map..")
|
||||
doc_out = '''
|
||||
<html><body><img id="imgmap" src="source.png" usemap="#map">
|
||||
<map name="map">%s</map><br>
|
||||
<button onclick="if(imgmap.src.includes('output.png')){imgmap.src='source.png';}else{imgmap.src='output.png';}">Swap Images</button>
|
||||
</body></html>
|
||||
'''
|
||||
img_map_areas = []
|
||||
img_out = torch.zeros((1,3,h*16,w*16))
|
||||
for i, ind in enumerate(tqdm(min_ids)):
|
||||
u = np.unravel_index(ind.item(), (num_maps*total_latents//(32*32),32,32))
|
||||
h_, w_ = np.unravel_index(i, (h, w))
|
||||
|
||||
img = ToTensor()(Resize((512, 512))(Image.open(img_bank_paths[u[0]])))
|
||||
t = 16 * u[1]
|
||||
l = 16 * u[2]
|
||||
patch = img[:, t:t+16, l:l+16]
|
||||
img_out[:,:,h_*16:h_*16+16,w_*16:w_*16+16] = patch
|
||||
|
||||
# Also save the image with a masked map
|
||||
mask = torch.full_like(img, fill_value=.3)
|
||||
mask[:, t:t+16, l:l+16] = 1
|
||||
masked_img = img * mask
|
||||
masked_src_img_output_file = os.path.join(output_path, "%i_%i__%i.png" % (t, l, u[0]))
|
||||
torchvision.utils.save_image(masked_img, masked_src_img_output_file)
|
||||
|
||||
# Update the image map areas.
|
||||
img_map_areas.append('<area shape="rect" coords="%i,%i,%i,%i" href="%s">' % (w_*16,h_*16,w_*16+16,h_*16+16,masked_src_img_output_file))
|
||||
torchvision.utils.save_image(img_out, os.path.join(output_path, "output.png"))
|
||||
torchvision.utils.save_image(img_t, os.path.join(output_path, "source.png"))
|
||||
doc_out = doc_out % ('\n'.join(img_map_areas))
|
||||
with open(os.path.join(output_path, 'map.html'), 'w') as f:
|
||||
print(doc_out, file=f)
|
||||
|
||||
|
||||
def explore_latent_results(model):
|
||||
batch_size = 8
|
||||
batch_size = 16
|
||||
num_workers = 1
|
||||
output_path = '../../results/byol_spinenet_explore_latents/'
|
||||
|
||||
|
@ -77,7 +189,7 @@ def explore_latent_results(model):
|
|||
b, c, h, w = latent.shape
|
||||
center = latent[:, :, h//2, w//2].unsqueeze(-1).unsqueeze(-1)
|
||||
centers = center.repeat(1, 1, h, w)
|
||||
dist = structural_euc_dist(latent, centers).unsqueeze(1)
|
||||
dist = cosine_similarity(latent, centers).unsqueeze(1)
|
||||
dist = im_norm(dist)
|
||||
torchvision.utils.save_image(dist, os.path.join(output_path, "%i.png" % id))
|
||||
id += 1
|
||||
|
@ -91,4 +203,4 @@ if __name__ == '__main__':
|
|||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
explore_latent_results(model)
|
||||
find_similar_latents(model)
|
Loading…
Reference in New Issue
Block a user