diff --git a/codes/models/byol/byol_for_semantic_chaining.py b/codes/models/byol/byol_for_semantic_chaining.py index 8de0846e..a8b4fbdf 100644 --- a/codes/models/byol/byol_for_semantic_chaining.py +++ b/codes/models/byol/byol_for_semantic_chaining.py @@ -208,21 +208,21 @@ class NetWrapper(nn.Module): projector = MLP(dim, self.projection_size, self.projection_hidden_size) return projector.to(hidden) - def get_representation(self, x, pt): + def get_representation(self, **kwargs): if self.layer == -1: - return self.net(x, pt) + return self.net(**kwargs) if not self.hook_registered: self._register_hook() - unused = self.net(x, pt) + unused = self.net(**kwargs) hidden = self.hidden self.hidden = None assert hidden is not None, f'hidden layer {self.layer} never emitted an output' return hidden - def forward(self, x, pt): - representation = self.get_representation(x, pt) + def forward(self, **kwargs): + representation = self.get_representation(**kwargs) projector = self._get_projector(representation) projection = checkpoint(projector, representation) return projection @@ -239,6 +239,7 @@ class BYOL(nn.Module): moving_average_decay=0.99, use_momentum=True, structural_mlp=False, + contrastive=False, ): super().__init__() @@ -247,6 +248,7 @@ class BYOL(nn.Module): self.aug = PointwiseAugmentor(image_size) self.use_momentum = use_momentum + self.contrastive = contrastive self.target_encoder = None self.target_ema_updater = EMA(moving_average_decay) @@ -278,13 +280,17 @@ class BYOL(nn.Module): def get_debug_values(self, step, __): # In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value. - return {'target_ema_beta': self.target_ema_updater.beta} + dbg = {'target_ema_beta': self.target_ema_updater.beta} + if self.contrastive and hasattr(self, 'logs_closs'): + dbg['contrastive_distance'] = self.logs_closs + dbg['byol_distance'] = self.logs_loss + return dbg def visual_dbg(self, step, path): torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,))) torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,))) - def forward(self, image): + def get_predictions_and_projections(self, image): _, _, h, w = image.shape point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device) @@ -297,16 +303,20 @@ class BYOL(nn.Module): self.im2 = image_two.detach().clone() self.im2[:,:,pt_two[0]-3:pt_two[0]+3,pt_two[1]-3:pt_two[1]+3] = 1 - online_proj_one = self.online_encoder(image_one, pt_one) - online_proj_two = self.online_encoder(image_two, pt_two) + online_proj_one = self.online_encoder(img=image_one, pos=pt_one) + online_proj_two = self.online_encoder(img=image_two, pos=pt_two) online_pred_one = self.online_predictor(online_proj_one) online_pred_two = self.online_predictor(online_proj_two) with torch.no_grad(): target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder - target_proj_one = target_encoder(image_one, pt_one).detach() - target_proj_two = target_encoder(image_two, pt_two).detach() + target_proj_one = target_encoder(img=image_one, pos=pt_one).detach() + target_proj_two = target_encoder(img=image_two, pos=pt_two).detach() + return online_pred_one, online_pred_two, target_proj_one, target_proj_two + + def forward_normal(self, image): + online_pred_one, online_pred_two, target_proj_one, target_proj_two = self.get_predictions_and_projections(image) loss_one = loss_fn(online_pred_one, target_proj_two.detach()) loss_two = loss_fn(online_pred_two, 
target_proj_one.detach()) @@ -314,6 +324,35 @@ class BYOL(nn.Module): loss = loss_one + loss_two return loss.mean() + def forward_contrastive(self, image): + online_pred_one_1, online_pred_two_1, target_proj_one_1, target_proj_two_1 = self.get_predictions_and_projections(image) + loss_one = loss_fn(online_pred_one_1, target_proj_two_1.detach()) + loss_two = loss_fn(online_pred_two_1, target_proj_one_1.detach()) + loss = loss_one + loss_two + + online_pred_one_2, online_pred_two_2, target_proj_one_2, target_proj_two_2 = self.get_predictions_and_projections(image) + loss_one = loss_fn(online_pred_one_2, target_proj_two_2.detach()) + loss_two = loss_fn(online_pred_two_2, target_proj_one_2.detach()) + loss = (loss + loss_one + loss_two).mean() + + contrastive_loss = torch.cat([loss_fn(online_pred_one_1, target_proj_two_2), + loss_fn(online_pred_two_1, target_proj_one_2), + loss_fn(online_pred_one_2, target_proj_two_1), + loss_fn(online_pred_two_2, target_proj_one_1)], dim=0) + k = contrastive_loss.shape[0] // 2 # Keep the half of the cross-point pairs with the largest distances. + contrastive_loss = torch.topk(contrastive_loss, k, dim=0).values.mean() + + self.logs_loss = loss.detach() + self.logs_closs = contrastive_loss.detach() + + # Subtracting the contrastive term rewards keeping latents drawn from different points apart. + return loss - contrastive_loss + + def forward(self, image): + if self.contrastive: + return self.forward_contrastive(image) + else: + return self.forward_normal(image) + if __name__ == '__main__': pa = PointwiseAugmentor(256) @@ -331,4 +370,4 @@ if __name__ == '__main__': @register_model def register_pixel_local_byol(opt_net, opt): subnet = create_model(opt, opt_net['subnet']) - return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer']) \ No newline at end of file + return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], contrastive=opt_net['contrastive']) \ No newline at end of file diff --git a/codes/models/segformer/segformer.py b/codes/models/segformer/segformer.py index a2bdb103..04daa225 100644 --- a/codes/models/segformer/segformer.py +++ b/codes/models/segformer/segformer.py @@ -94,14 +94,21 @@ class Segformer(nn.Module): self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(layers)]) self.tail = Tail() - def forward(self, x, pos): - layers = self.backbone(x) - set = [] + def forward(self, img=None, layers=None, pos=None, return_layers=False): + assert img is not None or layers is not None + if img is not None: + bs = img.shape[0] + layers = self.backbone(img) + else: + bs = layers[0].shape[0] + # return_layers lets a caller compute the backbone feature pyramid once and feed it back in via layers= for further positions. + if return_layers: + return layers # A single position can be optionally given, in which case we need to expand it to represent the entire input.
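+ # e.g. a pos of shape (2,) is broadcast below so the same sampling point is applied to every image in the batch.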
if pos.shape == (2,): - pos = pos.unsqueeze(0).repeat(x.shape[0],1) + pos = pos.unsqueeze(0).repeat(bs, 1) + set = [] pos = pos // 4 for layer_out, dilator in zip(layers, self.dilators): for subdilator in dilator: @@ -124,4 +131,4 @@ if __name__ == '__main__': model = Segformer().to('cuda') for j in tqdm(range(1000)): test_tensor = torch.randn(64,3,224,224).cuda() - print(model(test_tensor, torch.randint(0,224,(64,2)).cuda()).shape) \ No newline at end of file + print(model(img=test_tensor, pos=torch.randint(0,224,(64,2)).cuda()).shape) \ No newline at end of file diff --git a/codes/scripts/byol/byol_segformer_playground.py b/codes/scripts/byol/byol_segformer_playground.py new file mode 100644 index 00000000..a5bfcb0b --- /dev/null +++ b/codes/scripts/byol/byol_segformer_playground.py @@ -0,0 +1,240 @@ +import os +import shutil + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +from PIL import Image +from torch.utils.data import DataLoader +from torchvision.transforms import ToTensor, Resize +from tqdm import tqdm +import numpy as np + +import utils +from data.image_folder_dataset import ImageFolderDataset +from models.resnet_with_checkpointing import resnet50 +from models.segformer.segformer import Segformer +from models.spinenet_arch import SpineNet + + +# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved +# and the distance is computed across the channel dimension. +from utils import util +from utils.kmeans import kmeans, kmeans_predict +from utils.options import dict_to_nonedict + + +def structural_euc_dist(x, y): + diff = torch.square(x - y) + sum = torch.sum(diff, dim=-1) + return torch.sqrt(sum) + + +def cosine_similarity(x, y): + x = norm(x) + y = norm(y) + return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself. 
+ + + def key_value_difference(x, y): + x = F.normalize(x, dim=-1, p=2) + y = F.normalize(y, dim=-1, p=2) + return 2 - 2 * (x * y).sum(dim=-1) + + + def norm(x): + sh = x.shape + sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))]) + return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r) + + + def im_norm(x): + return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5 + + + def get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=True): + dataset_opt = dict_to_nonedict({ + 'name': 'amalgam', + #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped2'], + #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], + #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_tiled_filtered_flattened'], + #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], + 'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'], + 'weights': [1], + 'target_size': target_size, + 'force_multiple': 32, + 'normalize': 'imagenet', + 'scale': 1 + }) + dataset = ImageFolderDataset(dataset_opt) + return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle) + + + def _find_layer(net, layer_name): + if type(layer_name) == str: + modules = dict([*net.named_modules()]) + return modules.get(layer_name, None) + elif type(layer_name) == int: + children = [*net.children()] + return children[layer_name] + return None + + + layer_hooked_value = None + def _hook(_, __, output): + global layer_hooked_value + layer_hooked_value = output + + + def register_hook(net, layer_name): + layer = _find_layer(net, layer_name) + assert layer is not None, f'hidden layer ({layer_name}) not found' + layer.register_forward_hook(_hook) + + + def get_latent_for_img(model, img): + img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0) + _, _, h, w = img_t.shape + # Center crop img_t and resize to 224. + d = min(h, w) + dh, dw = (h-d)//2, (w-d)//2 + if dw != 0: + img_t = img_t[:, :, :, dw:-dw] + elif dh != 0: + img_t = img_t[:, :, dh:-dh, :] + img_t = img_t[:,:3,:,:] + img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area") + model(img_t) + latent = layer_hooked_value + return latent + + + def produce_latent_dict(model): + batch_size = 32 + num_workers = 4 + dataloader = get_image_folder_dataloader(batch_size, num_workers) + id = 0 + paths = [] + latents = [] + points = [] + for batch in tqdm(dataloader): + hq = batch['hq'].to('cuda') + # Pull several points from every image.
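+ # Each of the 10 draws below picks a single interior point that is shared by every image in the batch; the hooked layer's output at that point is stored per image.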
+ for k in range(10): + _, _, h, _ = hq.shape + point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device) + model(img=hq, pos=point) + l = layer_hooked_value.cpu().split(1, dim=0) + latents.extend(l) + points.extend([point for p in range(batch_size)]) + paths.extend(batch['HQ_path']) + id += batch_size + if id > 10000: + print("Saving checkpoint..") + torch.save((latents, points, paths), '../results_segformer.pth') + id = 0 + + + def find_similar_latents(model, compare_fn=structural_euc_dist): + global layer_hooked_value + + img = 'D:\\dlas\\results\\bobz.png' + #img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg' + output_path = '../../../results/byol_resnet_similars' + os.makedirs(output_path, exist_ok=True) + imglatent = get_latent_for_img(model, img).squeeze().unsqueeze(0) + _, c = imglatent.shape + + batch_size = 512 + num_workers = 8 + dataloader = get_image_folder_dataloader(batch_size, num_workers) + id = 0 + output_batch = 1 + results = [] + result_paths = [] + for batch in tqdm(dataloader): + hq = batch['hq'].to('cuda') + model(hq) + latent = layer_hooked_value.clone().squeeze() + compared = compare_fn(imglatent.repeat(latent.shape[0], 1), latent) + results.append(compared.cpu()) + result_paths.extend(batch['HQ_path']) + id += batch_size + if id > 10000: + k = 200 + results = torch.cat(results, dim=0) + vals, inds = torch.topk(results, k, largest=False) + for i in inds: + mag = int(results[i].item() * 1000) + shutil.copy(result_paths[i], os.path.join(output_path, f'{mag:05}_{output_batch}_{i}.jpg')) + results = [] + result_paths = [] + id = 0 + + + def build_kmeans(): + latents, _, _ = torch.load('../results_segformer.pth') + latents = torch.cat(latents, dim=0).squeeze().to('cuda') + cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=16, distance="euclidean", device=torch.device('cuda:0')) + torch.save((cluster_ids_x, cluster_centers), '../k_means_segformer.pth') + + + class UnNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + Returns: + Tensor: Normalized image. + """ + for t, m, s in zip(tensor, self.mean, self.std): + t.mul_(s).add_(m) + # The normalize code -> t.sub_(m).div_(s) + return tensor + + + def use_kmeans(): + output = "../results/k_means_segformer/" + _, centers = torch.load('../k_means_segformer.pth') + centers = centers.to('cuda') + batch_size = 32 + num_workers = 1 + dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=True) + denorm = UnNormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + for i, batch in enumerate(tqdm(dataloader)): + hq = batch['hq'].to('cuda') + _,_,h,w = hq.shape + point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device) + model(img=hq, pos=point) + l = layer_hooked_value.clone().squeeze() + pred = kmeans_predict(l, centers) + hq = denorm(hq * .5) + hq[:,:,point[0]-5:point[0]+5,point[1]-5:point[1]+5] *= 2 + for b in range(pred.shape[0]): + outpath = os.path.join(output, str(pred[b].item())) + os.makedirs(outpath, exist_ok=True) + torchvision.utils.save_image(hq[b], os.path.join(outpath, f'{i*batch_size+b}.png')) + + + if __name__ == '__main__': + pretrained_path = '../../../experiments/segformer_byol_only.pth' + model = Segformer().to('cuda') + sd = torch.load(pretrained_path) + resnet_sd = {} + for k, v in sd.items(): + if 'target_encoder.net.'
in k: + resnet_sd[k.replace('target_encoder.net.', '')] = v + model.load_state_dict(resnet_sd, strict=True) + model.eval() + register_hook(model, 'tail') + + with torch.no_grad(): + #find_similar_latents(model, structural_euc_dist) + #produce_latent_dict(model) + #build_kmeans() + use_kmeans() diff --git a/codes/scripts/byol/tsne_torch.py b/codes/scripts/byol/tsne_torch.py index 9a70eb77..03570a83 100644 --- a/codes/scripts/byol/tsne_torch.py +++ b/codes/scripts/byol/tsne_torch.py @@ -344,11 +344,75 @@ def plot_pixel_level_results_as_image_graph(): pyplot.savefig('tsne_pix.pdf') +def run_tsne_segformer(): + print("Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset.") + + limit = 10000 + X, points, files = torch.load('../results_segformer.pth') + zipped = list(zip(X, points, files)) + shuffle(zipped) + X, points, files = zip(*zipped) + X = torch.cat(X, dim=0).squeeze()[:limit] + labels = np.zeros(X.shape[0]) # We don't have any labels.. + + # confirm that x file get same number point than label file + # otherwise may cause error in scatter + assert(len(X[:, 0])==len(X[:,1])) + assert(len(X)==len(labels)) + + with torch.no_grad(): + Y = tsne(X, 2, 1024, 20.0) + + if opt.cuda: + Y = Y.cpu().numpy() + + # You may write result in two files + # print("Save Y values in file") + # Y1 = open("y1.txt", 'w') + # Y2 = open('y2.txt', 'w') + # for i in range(Y.shape[0]): + # Y1.write(str(Y[i,0])+"\n") + # Y2.write(str(Y[i,1])+"\n") + + pyplot.scatter(Y[:, 0], Y[:, 1], 20, labels) + pyplot.show() + torch.save((Y, points, files[:limit]), "../tsne_output.pth") + + +# Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne +# spectrum. +def plot_segformer_results_as_image_graph(): + Y, points, files = torch.load('../tsne_output.pth') + fig, ax = pyplot.subplots() + fig.set_size_inches(200,200,forward=True) + ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]])) + ax.autoscale() + + margins = 32 + for b in tqdm(range(Y.shape[0])): + imgfile = files[b] + baseim = pyplot.imread(imgfile) + ct, cl = points[b] + + im = baseim[(ct-margins):(ct+margins), + (cl-margins):(cl+margins),:] + im = OffsetImage(im, zoom=1) + ab = AnnotationBbox(im, (Y[b, 0], Y[b, 1]), xycoords='data', frameon=False) + ax.add_artist(ab) + ax.scatter(Y[:, 0], Y[:, 1]) + + pyplot.savefig('tsne_segformer.pdf') + + if __name__ == "__main__": # For use with instance-level results (e.g. from byol_resnet_playground.py) #run_tsne_instance_level() - plot_instance_level_results_as_image_graph() + #plot_instance_level_results_as_image_graph() # For use with pixel-level results (e.g. 
from byol_uresnet_playground) #run_tsne_pixel_level() - #plot_pixel_level_results_as_image_graph() \ No newline at end of file + #plot_pixel_level_results_as_image_graph() + + # For use with segformer results + #run_tsne_segformer() + plot_segformer_results_as_image_graph() \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index b809a286..99bf1ff8 100644 --- a/codes/train.py +++ b/codes/train.py @@ -295,7 +295,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_xx.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_contrastive_xx.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/eval/single_point_pair_contrastive_eval.py b/codes/trainer/eval/single_point_pair_contrastive_eval.py index 6a84cb34..8cc0278f 100644 --- a/codes/trainer/eval/single_point_pair_contrastive_eval.py +++ b/codes/trainer/eval/single_point_pair_contrastive_eval.py @@ -33,8 +33,8 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator): distances = [] l2 = MSELoss() for i, data in tqdm(enumerate(dl)): - latent1 = self.model(data['img1'].to(dev), torch.stack(data['coords1'], dim=1).to(dev)) - latent2 = self.model(data['img2'].to(dev), torch.stack(data['coords2'], dim=1).to(dev)) + latent1 = self.model(img=data['img1'].to(dev), pos=torch.stack(data['coords1'], dim=1).to(dev)) + latent2 = self.model(img=data['img2'].to(dev), pos=torch.stack(data['coords2'], dim=1).to(dev)) distances.append(l2(latent1, latent2)) if i * self.batch_sz >= self.eval_qty: break @@ -52,7 +52,7 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator): diff = dissimilars.item() - similars.item() print(f"Eval done. val_similar_lq: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}; val_diff: {diff}") self.model.train() - return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item(), "val_diff": diff.item()} + return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item(), "val_diff": diff} if __name__ == '__main__':
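
For reference, a minimal smoke-test sketch of the refactored interfaces above (not part of the patch). It assumes a CUDA device, uses 224 and 'tail' as illustrative stand-ins for the image_size/hidden_layer options, and otherwise relies only on signatures visible in this diff (Segformer's keyword-based forward, BYOL(..., contrastive=True), and get_debug_values):

import torch
from models.segformer.segformer import Segformer
from models.byol.byol_for_semantic_chaining import BYOL

# Build the subnet and wrap it in the contrastive BYOL variant introduced above.
# Positional arguments follow register_pixel_local_byol: (net, image_size, hidden_layer).
subnet = Segformer().to('cuda')
byol = BYOL(subnet, 224, 'tail', contrastive=True).to('cuda')

images = torch.randn(4, 3, 224, 224).to('cuda')
loss = byol(images)          # dispatches to forward_contrastive() because contrastive=True
loss.backward()
print(byol.get_debug_values(0, None))  # now also reports byol_distance / contrastive_distance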