diff --git a/codes/data/byol_attachment.py b/codes/data/byol_attachment.py index 0015b41a..55ae1961 100644 --- a/codes/data/byol_attachment.py +++ b/codes/data/byol_attachment.py @@ -4,7 +4,7 @@ from time import time import torch import torchvision from torch.utils.data import Dataset -from kornia import augmentation as augs +from kornia import augmentation as augs, kornia from kornia import filters import torch.nn as nn import torch.nn.functional as F @@ -98,10 +98,11 @@ class RandomSharedRegionCrop(nn.Module): # 2. Pick a random width, height and top corner location for the first patch. # 3. Pick a random width, height and top corner location for the second patch. # Note: All dims from (2) and (3) must contain at least half of the image, guaranteeing overlap. - # 6. Build patches from input images. Resize them appropriately. Apply translational jitter. - # 7. Compute the metrics needed to extract overlapping regions from the resized patches: top, left, + # 4. Build patches from input images. Resize them appropriately. Apply translational jitter.\ + # 5. Randomly flip image 2 if needed. + # 5. Compute the metrics needed to extract overlapping regions from the resized patches: top, left, # original_height, original_width. - # 8. Compute the "shared_view" from the above data. + # 6. Compute the "shared_view" from the above data. # Step 1 c, d, _ = i1.shape @@ -122,7 +123,7 @@ class RandomSharedRegionCrop(nn.Module): im2_t = random.randint(0, d-im2_h) im2_r, im2_b = im2_l+im2_w, im2_t+im2_h - # Step 6 + # Step 4 m = self.multiple jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range) jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative. @@ -139,14 +140,22 @@ class RandomSharedRegionCrop(nn.Module): p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl] p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear") - # Step 7 + # Step 5 + should_flip = random.random() < .5 + if should_flip: + should_flip = 1 + p2_resized = kornia.geometry.transform.hflip(p2_resized) + else: + should_flip = 0 + + # Step 6 i1_shared_t, i1_shared_l = snap(base_t, im2_t), snap(base_l, im2_l) i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l) ix_h = min(base_b, im2_b) - max(base_t, im2_t) ix_w = min(base_r, im2_r) - max(base_l, im2_l) - recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w], dtype=torch.long) + recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long) - # Step 8 + # Step 7 mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5) mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1 masked1 = pad_to(p1 * mask1, d*m) @@ -171,10 +180,14 @@ def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor): # It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside # of conforming the recompute_package across the entire batch. for b in range(package.shape[0]): - f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = tuple(package[b].tolist()) + f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple(package[b].tolist()) + # Unflip 2 if needed. + f2 = fea2[b] + if should_flip == 1: + f2 = kornia.geometry.transform.hflip(f2) # Resize the input features to match f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="bilinear") - f2s = F.interpolate(fea2[b].unsqueeze(0), (f2_h, f2_w), mode="bilinear") + f2s = F.interpolate(f2.unsqueeze(0), (f2_h, f2_w), mode="bilinear") # Outputs must be padded so they can "get along" with each other. res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim)) res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim)) diff --git a/codes/models/byol/byol_structural.py b/codes/models/byol/byol_structural.py index 9d423e60..3ebe5fcf 100644 --- a/codes/models/byol/byol_structural.py +++ b/codes/models/byol/byol_structural.py @@ -176,3 +176,8 @@ class StructuralBYOL(nn.Module): loss = loss_one + loss_two return loss.mean() + + def get_projection(self, image): + enc = self.online_encoder(image) + proj = self.online_predictor(enc) + return enc, proj \ No newline at end of file diff --git a/codes/scripts/byol_extract_wrapped_model.py b/codes/scripts/byol_extract_wrapped_model.py index 6782ac13..545e51d9 100644 --- a/codes/scripts/byol_extract_wrapped_model.py +++ b/codes/scripts/byol_extract_wrapped_model.py @@ -3,8 +3,8 @@ import torch from models.archs.spinenet_arch import SpineNet if __name__ == '__main__': - pretrained_path = '../../experiments/train_byol_512unsupervised/models/117000_generator.pth' - output_path = '../../experiments/spinenet49_imgset_byol.pth' + pretrained_path = '../../experiments/train_sbyol_512unsupervised/models/35000_generator.pth' + output_path = '../../experiments/spinenet49_imgset_sbyol.pth' wrap_key = 'online_encoder.net.' sd = torch.load(pretrained_path) diff --git a/codes/scripts/byol_spinenet_playground.py b/codes/scripts/byol_spinenet_playground.py index 272e0c0d..33558b85 100644 --- a/codes/scripts/byol_spinenet_playground.py +++ b/codes/scripts/byol_spinenet_playground.py @@ -3,6 +3,7 @@ import shutil import torch import torch.nn as nn +import torch.nn.functional as F import torchvision from PIL import Image from torch.utils.data import DataLoader @@ -10,12 +11,16 @@ from torchvision.transforms import ToTensor, Resize from tqdm import tqdm import numpy as np +import utils from data.image_folder_dataset import ImageFolderDataset from models.archs.spinenet_arch import SpineNet # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved # and the distance is computed across the channel dimension. +from utils import util + + def structural_euc_dist(x, y): diff = torch.square(x - y) sum = torch.sum(diff, dim=-1) @@ -28,6 +33,12 @@ def cosine_similarity(x, y): return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself. +def key_value_difference(x, y): + x = F.normalize(x, dim=-1, p=2) + y = F.normalize(y, dim=-1, p=2) + return 2 - 2 * (x * y).sum(dim=-1) + + def norm(x): sh = x.shape sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))]) @@ -41,8 +52,8 @@ def im_norm(x): def get_image_folder_dataloader(batch_size, num_workers): dataset_opt = { 'name': 'amalgam', - #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], - 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], + 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], + #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], 'weights': [1], 'target_size': 512, 'force_multiple': 32, @@ -52,7 +63,7 @@ def get_image_folder_dataloader(batch_size, num_workers): return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) -def create_latent_database(model): +def create_latent_database(model, model_index=0): batch_size = 8 num_workers = 1 output_path = '../../results/byol_spinenet_latents/' @@ -65,7 +76,7 @@ def create_latent_database(model): all_paths = [] for batch in tqdm(dataloader): hq = batch['hq'].to('cuda') - latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. + latent = model(hq)[model_index] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. for b in range(latent.shape[0]): im_path = batch['HQ_path'][b] all_paths.append(im_path) @@ -79,14 +90,8 @@ def create_latent_database(model): id += 1 -def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size): +def _get_mins_from_comparables(latent, comparables, batch_size, compare_fn): _, c, h, w = latent.shape - lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name)) - comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1) - cbl_shape = comparables.shape[:3] - assert cbl_shape[1] == 32 - comparables = comparables.reshape(-1, c) - clat = latent.reshape(1,-1,h*w).permute(2,0,1) cpbl_chunked = torch.chunk(comparables, len(comparables) // batch_size) assert len(comparables) % batch_size == 0 # The reconstruction logic doesn't work if this is not the case. @@ -94,11 +99,12 @@ def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_si min_offsets = [] for cpbl_chunk in tqdm(cpbl_chunked): cpbl_chunk = cpbl_chunk.to('cuda') - dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0)) + dist = compare_fn(clat, cpbl_chunk.unsqueeze(0)) _min = torch.min(dist, dim=-1) mins.append(_min[0]) min_offsets.append(_min[1]) mins = torch.min(torch.stack(mins, dim=-1), dim=-1) + # There's some way to do this in torch, I just can't figure it out.. for i in range(len(mins[1])): mins[1][i] = mins[1][i] * batch_size + min_offsets[mins[1][i]][i] @@ -106,26 +112,36 @@ def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_si return mins[0].cpu(), mins[1].cpu(), len(comparables) -def find_similar_latents(model): +def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size, compare_fn): + _, c, h, w = latent.shape + lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name)) + comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1) + cbl_shape = comparables.shape[:3] + comparables = comparables.reshape(-1, c) + return _get_mins_from_comparables(latent, comparables, batch_size, compare_fn) + + +def find_similar_latents(model, model_index=0, lat_patch_size=16, compare_fn=structural_euc_dist): img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg' #img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg' hq_img_repo = '../../results/byol_spinenet_latents' output_path = '../../results/byol_spinenet_similars' - batch_size = 1024 - num_maps = 8 + batch_size = 2048 + num_maps = 4 + lat_patch_mult = 512 // lat_patch_size os.makedirs(output_path, exist_ok=True) img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth")) img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0) _, _, h, w = img_t.shape img_t = img_t[:, :, :128*(h//128), :128*(w//128)] - - latent = model(img_t)[1] + latent = model(img_t)[model_index] _, c, h, w = latent.shape + mins, min_offsets = [], [] total_latents = -1 for d_id in range(1,num_maps+1): - mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size) + mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size, compare_fn) if total_latents != -1: assert total_latents == tl else: @@ -140,32 +156,37 @@ def find_similar_latents(model): print("Constructing image map..") doc_out = ''' - + %s
''' img_map_areas = [] - img_out = torch.zeros((1,3,h*16,w*16)) + img_out = torch.zeros((1, 3, h * lat_patch_size, w * lat_patch_size)) for i, ind in enumerate(tqdm(min_ids)): - u = np.unravel_index(ind.item(), (num_maps*total_latents//(32*32),32,32)) + u = np.unravel_index(ind.item(), (num_maps * total_latents // (lat_patch_mult ** 2), lat_patch_mult, lat_patch_mult)) h_, w_ = np.unravel_index(i, (h, w)) img = ToTensor()(Resize((512, 512))(Image.open(img_bank_paths[u[0]]))) - t = 16 * u[1] - l = 16 * u[2] - patch = img[:, t:t+16, l:l+16] - img_out[:,:,h_*16:h_*16+16,w_*16:w_*16+16] = patch + t = lat_patch_size * u[1] + l = lat_patch_size * u[2] + patch = img[:, t:t + lat_patch_size, l:l + lat_patch_size] + img_out[:,:, h_ * lat_patch_size:h_ * lat_patch_size + lat_patch_size, + w_ * lat_patch_size:w_ * lat_patch_size + lat_patch_size] = patch # Also save the image with a masked map mask = torch.full_like(img, fill_value=.3) - mask[:, t:t+16, l:l+16] = 1 + mask[:, t:t + lat_patch_size, l:l + lat_patch_size] = 1 masked_img = img * mask masked_src_img_output_file = os.path.join(output_path, "%i_%i__%i.png" % (t, l, u[0])) torchvision.utils.save_image(masked_img, masked_src_img_output_file) # Update the image map areas. - img_map_areas.append('' % (w_*16,h_*16,w_*16+16,h_*16+16,masked_src_img_output_file)) + img_map_areas.append('' % (w_ * lat_patch_size, + h_ * lat_patch_size, + w_ * lat_patch_size + lat_patch_size, + h_ * lat_patch_size + lat_patch_size, + masked_src_img_output_file)) torchvision.utils.save_image(img_out, os.path.join(output_path, "output.png")) torchvision.utils.save_image(img_t, os.path.join(output_path, "source.png")) doc_out = doc_out % ('\n'.join(img_map_areas)) @@ -195,12 +216,30 @@ def explore_latent_results(model): id += 1 -if __name__ == '__main__': - pretrained_path = '../../experiments/spinenet49_imgset_byol.pth' +class BYOLModelWrapper(nn.Module): + def __init__(self, wrap): + super().__init__() + self.wrap = wrap + def forward(self, img): + return self.wrap.get_projection(img) + + +if __name__ == '__main__': + util.loaded_options = {'checkpointing_enabled': True} + pretrained_path = '../../experiments/spinenet49_imgset_sbyol.pth' model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda') model.load_state_dict(torch.load(pretrained_path), strict=True) model.eval() + #pretrained_path = '../../experiments/train_sbyol_512unsupervised/models/35000_generator.pth' + #from models.byol.byol_structural import StructuralBYOL + #subnet = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda') + #model = StructuralBYOL(subnet, image_size=256, hidden_layer='endpoint_convs.3.conv') + #model.load_state_dict(torch.load(pretrained_path), strict=True) + #model = BYOLModelWrapper(model) + #model.eval() + with torch.no_grad(): - find_similar_latents(model) \ No newline at end of file + #create_latent_database(model, 0) # 0 = model output dimension to use for latent storage + find_similar_latents(model, 0, 8, structural_euc_dist) # 1 = model output dimension to use for latent predictor. diff --git a/codes/train.py b/codes/train.py index 02695bd0..669dd90e 100644 --- a/codes/train.py +++ b/codes/train.py @@ -292,7 +292,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_512unsupervised.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_sbyol_512unsupervised.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()