Structural latents checkpoint

This commit is contained in:
James Betker 2020-12-11 12:01:09 -07:00
parent 26ceca68c0
commit ec0ee25f4b
5 changed files with 100 additions and 43 deletions

View File

@@ -4,7 +4,7 @@ from time import time
import torch
import torchvision
from torch.utils.data import Dataset
from kornia import augmentation as augs
from kornia import augmentation as augs, kornia
from kornia import filters
import torch.nn as nn
import torch.nn.functional as F
@@ -98,10 +98,11 @@ class RandomSharedRegionCrop(nn.Module):
# 2. Pick a random width, height and top corner location for the first patch.
# 3. Pick a random width, height and top corner location for the second patch.
# Note: All dims from (2) and (3) must contain at least half of the image, guaranteeing overlap.
# 6. Build patches from input images. Resize them appropriately. Apply translational jitter.
# 7. Compute the metrics needed to extract overlapping regions from the resized patches: top, left,
# 4. Build patches from input images. Resize them appropriately. Apply translational jitter.
# 5. Randomly flip image 2 if needed.
# 6. Compute the metrics needed to extract overlapping regions from the resized patches: top, left,
# original_height, original_width.
# 8. Compute the "shared_view" from the above data.
# 7. Compute the "shared_view" from the above data.
# Step 1
c, d, _ = i1.shape
@@ -122,7 +123,7 @@ class RandomSharedRegionCrop(nn.Module):
im2_t = random.randint(0, d-im2_h)
im2_r, im2_b = im2_l+im2_w, im2_t+im2_h
# Step 6
# Step 4
m = self.multiple
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative.
@@ -139,14 +140,22 @@ class RandomSharedRegionCrop(nn.Module):
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
# Step 7
# Step 5
should_flip = random.random() < .5
if should_flip:
should_flip = 1
p2_resized = kornia.geometry.transform.hflip(p2_resized)
else:
should_flip = 0
# Step 6
i1_shared_t, i1_shared_l = snap(base_t, im2_t), snap(base_l, im2_l)
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w], dtype=torch.long)
recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long)
# Step 8
# Step 7
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
masked1 = pad_to(p1 * mask1, d*m)
@@ -171,10 +180,14 @@ def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
# It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
# of conforming the recompute_package across the entire batch.
for b in range(package.shape[0]):
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = tuple(package[b].tolist())
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple(package[b].tolist())
# Unflip 2 if needed.
f2 = fea2[b]
if should_flip == 1:
f2 = kornia.geometry.transform.hflip(f2)
# Resize the input features to match
f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="bilinear")
f2s = F.interpolate(fea2[b].unsqueeze(0), (f2_h, f2_w), mode="bilinear")
f2s = F.interpolate(f2.unsqueeze(0), (f2_h, f2_w), mode="bilinear")
# Outputs must be padded so they can "get along" with each other.
res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
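For orientation, a minimal self-contained sketch of the round trip encoded by the new recompute_package: unpack the 11 fields written by RandomSharedRegionCrop (including the new should_flip flag), undo the flip, resize, and crop the shared window from each feature map. Toy shapes and torch.flip (standing in for kornia.geometry.transform.hflip) are assumptions here, not part of the commit:

import torch
import torch.nn.functional as F

def extract_shared_regions(f1, f2, package):
    # Field order follows the recompute_package tensor in the diff above.
    (f1_h, f1_w, f1s_t, f1s_l,
     f2_h, f2_w, f2s_t, f2s_l,
     should_flip, s_h, s_w) = package.tolist()
    if should_flip == 1:
        # Undo the horizontal flip applied to patch 2 during augmentation.
        f2 = torch.flip(f2, dims=[-1])
    # Resize each feature map back to its pre-resize patch geometry.
    f1s = F.interpolate(f1.unsqueeze(0), (f1_h, f1_w), mode="bilinear", align_corners=False)
    f2s = F.interpolate(f2.unsqueeze(0), (f2_h, f2_w), mode="bilinear", align_corners=False)
    # Crop the overlap window from each; both crops now cover the same image content.
    r1 = f1s[:, :, f1s_t:f1s_t + s_h, f1s_l:f1s_l + s_w]
    r2 = f2s[:, :, f2s_t:f2s_t + s_h, f2s_l:f2s_l + s_w]
    return r1, r2

# Toy example: two 64-channel feature maps and a hand-built package.
f1, f2 = torch.randn(64, 16, 16), torch.randn(64, 16, 16)
pkg = torch.tensor([12, 12, 2, 3, 10, 14, 0, 1, 1, 8, 9], dtype=torch.long)
r1, r2 = extract_shared_regions(f1, f2, pkg)
assert r1.shape == r2.shape  # both [1, 64, 8, 9]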

View File

@@ -176,3 +176,8 @@ class StructuralBYOL(nn.Module):
loss = loss_one + loss_two
return loss.mean()
def get_projection(self, image):
enc = self.online_encoder(image)
proj = self.online_predictor(enc)
return enc, proj

View File

@@ -3,8 +3,8 @@ import torch
from models.archs.spinenet_arch import SpineNet
if __name__ == '__main__':
pretrained_path = '../../experiments/train_byol_512unsupervised/models/117000_generator.pth'
output_path = '../../experiments/spinenet49_imgset_byol.pth'
pretrained_path = '../../experiments/train_sbyol_512unsupervised/models/35000_generator.pth'
output_path = '../../experiments/spinenet49_imgset_sbyol.pth'
wrap_key = 'online_encoder.net.'
sd = torch.load(pretrained_path)
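The rest of this extraction script is outside the diff; a plausible sketch, assuming it simply strips the wrap_key prefix so the online encoder's backbone weights load into a bare SpineNet:

import torch

pretrained_path = '../../experiments/train_sbyol_512unsupervised/models/35000_generator.pth'
output_path = '../../experiments/spinenet49_imgset_sbyol.pth'
wrap_key = 'online_encoder.net.'

sd = torch.load(pretrained_path)
# Keep only the online encoder's backbone weights and drop the wrapper prefix.
stripped = {k[len(wrap_key):]: v for k, v in sd.items() if k.startswith(wrap_key)}
torch.save(stripped, output_path)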

View File

@@ -3,6 +3,7 @@ import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
@@ -10,12 +11,16 @@ from torchvision.transforms import ToTensor, Resize
from tqdm import tqdm
import numpy as np
import utils
from data.image_folder_dataset import ImageFolderDataset
from models.archs.spinenet_arch import SpineNet
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
from utils import util
def structural_euc_dist(x, y):
diff = torch.square(x - y)
sum = torch.sum(diff, dim=-1)
@@ -28,6 +33,12 @@ def cosine_similarity(x, y):
return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
def key_value_difference(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
def norm(x):
sh = x.shape
sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
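For intuition on the new key_value_difference: with L2-normalized inputs it equals the squared euclidean distance, i.e. the usual BYOL regression loss. A quick check with made-up shapes:

import torch
import torch.nn.functional as F

x, y = torch.randn(4, 256), torch.randn(4, 256)
xn, yn = F.normalize(x, dim=-1, p=2), F.normalize(y, dim=-1, p=2)
byol_loss = 2 - 2 * (xn * yn).sum(dim=-1)        # key_value_difference
sq_euclidean = ((xn - yn) ** 2).sum(dim=-1)      # squared distance of unit vectors
assert torch.allclose(byol_loss, sq_euclidean, atol=1e-5)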
@@ -41,8 +52,8 @@ def im_norm(x):
def get_image_folder_dataloader(batch_size, num_workers):
dataset_opt = {
'name': 'amalgam',
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'weights': [1],
'target_size': 512,
'force_multiple': 32,
@@ -52,7 +63,7 @@ def get_image_folder_dataloader(batch_size, num_workers):
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
def create_latent_database(model):
def create_latent_database(model, model_index=0):
batch_size = 8
num_workers = 1
output_path = '../../results/byol_spinenet_latents/'
@@ -65,7 +76,7 @@ def create_latent_database(model):
all_paths = []
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
latent = model(hq)[model_index] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
for b in range(latent.shape[0]):
im_path = batch['HQ_path'][b]
all_paths.append(im_path)
@@ -79,14 +90,8 @@ def create_latent_database(model):
id += 1
def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size):
def _get_mins_from_comparables(latent, comparables, batch_size, compare_fn):
_, c, h, w = latent.shape
lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name))
comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1)
cbl_shape = comparables.shape[:3]
assert cbl_shape[1] == 32
comparables = comparables.reshape(-1, c)
clat = latent.reshape(1,-1,h*w).permute(2,0,1)
cpbl_chunked = torch.chunk(comparables, len(comparables) // batch_size)
assert len(comparables) % batch_size == 0 # The reconstruction logic doesn't work if this is not the case.
@@ -94,11 +99,12 @@ def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size):
min_offsets = []
for cpbl_chunk in tqdm(cpbl_chunked):
cpbl_chunk = cpbl_chunk.to('cuda')
dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0))
dist = compare_fn(clat, cpbl_chunk.unsqueeze(0))
_min = torch.min(dist, dim=-1)
mins.append(_min[0])
min_offsets.append(_min[1])
mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
# There's some way to do this in torch, I just can't figure it out..
for i in range(len(mins[1])):
mins[1][i] = mins[1][i] * batch_size + min_offsets[mins[1][i]][i]
@@ -106,26 +112,36 @@ def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size):
return mins[0].cpu(), mins[1].cpu(), len(comparables)
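The index-reconstruction loop above ("There's some way to do this in torch, I just can't figure it out..") can be vectorized with gather; a sketch under the assumption that min_offsets holds equal-length per-chunk argmin tensors, shown here on toy data:

import torch

batch_size = 4
num_chunks, num_latents = 3, 5
chunk_minvals = torch.rand(num_latents, num_chunks)                 # per-chunk minima, stacked on the last dim
min_offsets = [torch.randint(0, batch_size, (num_latents,)) for _ in range(num_chunks)]
mins = torch.min(chunk_minvals, dim=-1)                             # winning chunk per latent position
offsets = torch.stack(min_offsets, dim=0)                           # [num_chunks, num_latents]
within = offsets.gather(0, mins.indices.unsqueeze(0)).squeeze(0)    # offset inside the winning chunk
global_idx = mins.indices * batch_size + within                     # same result as the python loop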
def find_similar_latents(model):
def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size, compare_fn):
_, c, h, w = latent.shape
lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name))
comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1)
cbl_shape = comparables.shape[:3]
comparables = comparables.reshape(-1, c)
return _get_mins_from_comparables(latent, comparables, batch_size, compare_fn)
def find_similar_latents(model, model_index=0, lat_patch_size=16, compare_fn=structural_euc_dist):
img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg'
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
hq_img_repo = '../../results/byol_spinenet_latents'
output_path = '../../results/byol_spinenet_similars'
batch_size = 1024
num_maps = 8
batch_size = 2048
num_maps = 4
lat_patch_mult = 512 // lat_patch_size
os.makedirs(output_path, exist_ok=True)
img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth"))
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
_, _, h, w = img_t.shape
img_t = img_t[:, :, :128*(h//128), :128*(w//128)]
latent = model(img_t)[1]
latent = model(img_t)[model_index]
_, c, h, w = latent.shape
mins, min_offsets = [], []
total_latents = -1
for d_id in range(1,num_maps+1):
mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size)
mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size, compare_fn)
if total_latents != -1:
assert total_latents == tl
else:
@@ -140,32 +156,37 @@ def find_similar_latents(model):
print("Constructing image map..")
doc_out = '''
<html><body><img id="imgmap" src="source.png" usemap="#map">
<html><body><img id="imgmap" src="output.png" usemap="#map">
<map name="map">%s</map><br>
<button onclick="if(imgmap.src.includes('output.png')){imgmap.src='source.png';}else{imgmap.src='output.png';}">Swap Images</button>
</body></html>
'''
img_map_areas = []
img_out = torch.zeros((1,3,h*16,w*16))
img_out = torch.zeros((1, 3, h * lat_patch_size, w * lat_patch_size))
for i, ind in enumerate(tqdm(min_ids)):
u = np.unravel_index(ind.item(), (num_maps*total_latents//(32*32),32,32))
u = np.unravel_index(ind.item(), (num_maps * total_latents // (lat_patch_mult ** 2), lat_patch_mult, lat_patch_mult))
h_, w_ = np.unravel_index(i, (h, w))
img = ToTensor()(Resize((512, 512))(Image.open(img_bank_paths[u[0]])))
t = 16 * u[1]
l = 16 * u[2]
patch = img[:, t:t+16, l:l+16]
img_out[:,:,h_*16:h_*16+16,w_*16:w_*16+16] = patch
t = lat_patch_size * u[1]
l = lat_patch_size * u[2]
patch = img[:, t:t + lat_patch_size, l:l + lat_patch_size]
img_out[:,:, h_ * lat_patch_size:h_ * lat_patch_size + lat_patch_size,
w_ * lat_patch_size:w_ * lat_patch_size + lat_patch_size] = patch
# Also save the image with a masked map
mask = torch.full_like(img, fill_value=.3)
mask[:, t:t+16, l:l+16] = 1
mask[:, t:t + lat_patch_size, l:l + lat_patch_size] = 1
masked_img = img * mask
masked_src_img_output_file = os.path.join(output_path, "%i_%i__%i.png" % (t, l, u[0]))
torchvision.utils.save_image(masked_img, masked_src_img_output_file)
# Update the image map areas.
img_map_areas.append('<area shape="rect" coords="%i,%i,%i,%i" href="%s">' % (w_*16,h_*16,w_*16+16,h_*16+16,masked_src_img_output_file))
img_map_areas.append('<area shape="rect" coords="%i,%i,%i,%i" href="%s">' % (w_ * lat_patch_size,
h_ * lat_patch_size,
w_ * lat_patch_size + lat_patch_size,
h_ * lat_patch_size + lat_patch_size,
masked_src_img_output_file))
torchvision.utils.save_image(img_out, os.path.join(output_path, "output.png"))
torchvision.utils.save_image(img_t, os.path.join(output_path, "source.png"))
doc_out = doc_out % ('\n'.join(img_map_areas))
@@ -195,12 +216,30 @@ def explore_latent_results(model):
id += 1
if __name__ == '__main__':
pretrained_path = '../../experiments/spinenet49_imgset_byol.pth'
class BYOLModelWrapper(nn.Module):
def __init__(self, wrap):
super().__init__()
self.wrap = wrap
def forward(self, img):
return self.wrap.get_projection(img)
if __name__ == '__main__':
util.loaded_options = {'checkpointing_enabled': True}
pretrained_path = '../../experiments/spinenet49_imgset_sbyol.pth'
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
model.load_state_dict(torch.load(pretrained_path), strict=True)
model.eval()
#pretrained_path = '../../experiments/train_sbyol_512unsupervised/models/35000_generator.pth'
#from models.byol.byol_structural import StructuralBYOL
#subnet = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
#model = StructuralBYOL(subnet, image_size=256, hidden_layer='endpoint_convs.3.conv')
#model.load_state_dict(torch.load(pretrained_path), strict=True)
#model = BYOLModelWrapper(model)
#model.eval()
with torch.no_grad():
find_similar_latents(model)
#create_latent_database(model, 0) # 0 = model output dimension to use for latent storage
find_similar_latents(model, 0, 8, structural_euc_dist) # 0 = model output dimension to use for the latent predictor.

View File

@@ -292,7 +292,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_512unsupervised.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_sbyol_512unsupervised.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()