From 23e01314d4696e05aba5e5228a12b9435e2d0d11 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 23 Apr 2021 17:17:13 -0600 Subject: [PATCH] Add dataset, ui for labeling and evaluator for pointwise classification --- ..._pair_with_corresponding_points_dataset.py | 84 ++++++++++++++ .../ui/image_pair_labeler/image_pair_ui.py | 104 ++++++++++++++++++ .../single_point_pair_contrastive_eval.py | 48 ++++++++ 3 files changed, 236 insertions(+) create mode 100644 codes/data/image_pair_with_corresponding_points_dataset.py create mode 100644 codes/scripts/ui/image_pair_labeler/image_pair_ui.py create mode 100644 codes/trainer/eval/single_point_pair_contrastive_eval.py diff --git a/codes/data/image_pair_with_corresponding_points_dataset.py b/codes/data/image_pair_with_corresponding_points_dataset.py new file mode 100644 index 00000000..c6650080 --- /dev/null +++ b/codes/data/image_pair_with_corresponding_points_dataset.py @@ -0,0 +1,84 @@ +import glob +import itertools +import random + +import cv2 +import kornia +import numpy as np +import pytorch_ssim +import torch +import os + +import torchvision +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +from torchvision.transforms import Normalize +from tqdm import tqdm + +from data import util +# Builds a dataset created from a simple folder containing a list of training/test/validation images. +from data.image_corruptor import ImageCorruptor +from data.image_label_parser import VsNetImageLabeler +from utils.util import opt_get + + +class ImagePairWithCorrespondingPointsDataset(Dataset): + def __init__(self, opt): + self.opt = opt + self.path = opt['path'] + self.pairs = list(filter(lambda f: not os.path.isdir(f), os.listdir(self.path))) + self.transforms = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + self.size = opt['size'] + + + def __getitem__(self, item): + dir = self.pairs[item] + img1 = self.transforms(Image.open(os.path.join(self.path, dir, "1.jpg"))) + img2 = self.transforms(Image.open(os.path.join(self.path, dir, "2.jpg"))) + coords1, coords2 = torch.load(os.path.join(self.path, dir, "coords.pth")) + assert img1.shape[-2] == img1.shape[-1] + assert img2.shape[-2] == img2.shape[-1] + if img1.shape[-1] != self.size: + scale = img1.shape[-1] / self.size + assert(int(scale) == scale) # We will only downsample to even resolutions. + scale = 1 / scale + img1 = torch.nn.functional.interpolate(img1.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) + coords1 = [int(c * scale) for c in coords1] + if img2.shape[-1] != self.size: + scale = img2.shape[-1] / self.size + assert(int(scale) == scale) # We will only downsample to even resolutions. + scale = 1 / scale + img2 = torch.nn.functional.interpolate(img2.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) + coords2 = [int(c * scale) for c in coords2] + coords1 = (coords1[1], coords1[0]) # The UI puts these out backwards (x,y). Flip them. + coords2 = (coords2[1], coords2[0]) + return { + 'img1': img1, + 'img2': img2, + 'coords1': coords1, + 'coords2': coords2 + } + + def __len__(self): + return len(self.pairs) + +if __name__ == '__main__': + opt = { + 'path': 'F:\\dlas\\codes\\scripts\\ui\\image_pair_labeler\\results', + 'size': 256 + } + output_path = '.' + + ds = DataLoader(ImagePairWithCorrespondingPointsDataset(opt), shuffle=True, num_workers=0) + for i, d in tqdm(enumerate(ds)): + i1 = d['img1'] + i2 = d['img2'] + c1 = d['coords1'] + c2 = d['coords2'] + i1[:,:,c1[0]-3:c1[0]+3,c1[1]-3:c1[1]+3] = 0 + i2[:,:,c2[0]-3:c2[0]+3,c2[1]-3:c2[1]+3] = 0 + torchvision.utils.save_image(i1, f'{output_path}\\{i}_1.png') + torchvision.utils.save_image(i2, f'{output_path}\\{i}_2.png') \ No newline at end of file diff --git a/codes/scripts/ui/image_pair_labeler/image_pair_ui.py b/codes/scripts/ui/image_pair_labeler/image_pair_ui.py new file mode 100644 index 00000000..186df7d9 --- /dev/null +++ b/codes/scripts/ui/image_pair_labeler/image_pair_ui.py @@ -0,0 +1,104 @@ +# Script that builds and launches a tkinter UI for labeling similar points between two images. +import os +import tkinter as tk +from glob import glob +from random import choices + +import torch +from PIL import ImageTk, Image + +# Globals used to define state that event handlers might operate on. +imgs_list = [] +widgets = None +cur_img_1, cur_img_2 = None, None +pil_img_1, pil_img_2 = None, None +pending_labels = [] +mode_select_image_1 = True +img_count = 1 +img_loc_1 = None +output_location = "results" + + +def update_mode_label(): + global widgets, mode_select_image_1, img_count + image_widget_1, image_widget_2, mode_label = widgets + mode_str = "Select point in image 1" if mode_select_image_1 else "Select point in image 2" + mode_label.config(text="%s; Saved images: %i" % (mode_str, img_count)) + + +# Handles key presses, which are interpreted as requests to categorize a currently active image patch. +def key_press(event): + global batch_gen, labeler, pending_labels + + if event.char == '\t': + next_images() + + update_mode_label() + + +def click(event): + global img_loc_1, mode_select_image_1, pil_img_1, pil_img_2, img_count + x, y = event.x, event.y + if x > 512 or y > 512: + print(f"Bounds error {x} {y}") + return + + print(f"Detected click. {x} {y}") + if mode_select_image_1: + img_loc_1 = x, y + mode_select_image_1 = False + else: + ofolder = f'{output_location}/{img_count}' + os.makedirs(ofolder) + pil_img_1.save(os.path.join(ofolder, "1.jpg")) + pil_img_2.save(os.path.join(ofolder, "2.jpg")) + torch.save([img_loc_1, (x,y)], os.path.join(ofolder, "coords.pth")) + img_count = img_count + 1 + mode_select_image_1 = True + next_images() + update_mode_label() + + +def load_image_into_pane(img_path, pane, size=512): + pil_img = Image.open(img_path) + pil_img = pil_img.resize((size,size)) + tk_picture = ImageTk.PhotoImage(pil_img) + pane.image = tk_picture + pane.configure(image=tk_picture) + return pil_img + +def next_images(): + global imgs_list, widgets, cur_img_1, cur_img_2, pil_img_1, pil_img_2 + image_widget_1, image_widget_2, mode_label = widgets + + cur_img_1, cur_img_2 = choices(imgs_list, k=2) + pil_img_1 = load_image_into_pane(cur_img_1, image_widget_1) + pil_img_2 = load_image_into_pane(cur_img_2, image_widget_2) + +if __name__ == '__main__': + os.makedirs(output_location, exist_ok=True) + + window = tk.Tk() + window.title("Image pair labeler UI") + window.geometry('1024x620+100+100') + + # Load images + imgs_list = glob("E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new\\*.jpg") + + # Photo view. + image_widget_1 = tk.Label(window) + image_widget_1.place(x=0, y=0, width=512, height=512) + image_widget_2 = tk.Label(window) + image_widget_2.place(x=512, y=0, width=512, height=512) + + # Labels + mode_label = tk.Label(window, text="", anchor="w") + mode_label.place(x=20, y=590, width=400, height=20) + + widgets = (image_widget_1, image_widget_2, mode_label) + + window.bind("", key_press) # Skip current patch + window.bind("", click) + next_images() + update_mode_label() + window.mainloop() diff --git a/codes/trainer/eval/single_point_pair_contrastive_eval.py b/codes/trainer/eval/single_point_pair_contrastive_eval.py new file mode 100644 index 00000000..fccebbd5 --- /dev/null +++ b/codes/trainer/eval/single_point_pair_contrastive_eval.py @@ -0,0 +1,48 @@ +import os + +import torch +import os.path as osp +import torchvision +from torch.nn import MSELoss +from torch.utils.data import DataLoader +from tqdm import tqdm + +import trainer.eval.evaluator as evaluator +from pytorch_fid import fid_score + +from data.image_pair_with_corresponding_points_dataset import ImagePairWithCorrespondingPointsDataset +from utils.util import opt_get + +# Uses two datasets: a "similar" and "dissimilar" dataset, each of which contains pairs of images and similar/dissimilar +# points in those datasets. Uses the provided network to compute a latent vector for both similar and dissimilar. +# Reports a score for the l2 distance of both. A properly trained network will show similar points getting closer while +# dissimilar points remain constant or get further apart. +class SinglePointPairContrastiveEval(evaluator.Evaluator): + def __init__(self, model, opt_eval, env): + super().__init__(model, opt_eval, env) + self.batch_sz = opt_eval['batch_size'] + self.eval_qty = opt_eval['quantity'] + assert self.eval_qty % self.batch_sz == 0 + self.similar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(**opt_eval['similar_set_args']), shuffle=False, batch_size=self.batch_sz) + self.dissimilar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(**opt_eval['dissimilar_set_args']), shuffle=False, batch_size=self.batch_sz) + + def get_l2_score(self, dl): + distances = [] + l2 = MSELoss() + for i, data in tqdm(enumerate(dl)): + latent1 = self.model(data['img1'], data['coords1']) + latent2 = self.model(data['img2'], data['coords2']) + distances.append(l2(latent1, latent2)) + if i * self.batch_sz >= self.eval_qty: + break + + return torch.stack(distances).mean() + + def perform_eval(self): + self.model.eval() + print("Computing contrastive eval on similar set") + similars = self.get_l2_score(self.similar_set) + print("Computing contrastive eval on dissimilar set") + dissimilars = self.get_l2_score(self.dissimilar_set) + print(f"Eval done. val_similar_lq: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}") + return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item()} \ No newline at end of file