Add dataset, UI for labeling, and evaluator for pointwise classification

This commit is contained in:
James Betker 2021-04-23 17:17:13 -06:00
parent fc623d4b5a
commit 23e01314d4
3 changed files with 236 additions and 0 deletions

View File: data/image_pair_with_corresponding_points_dataset.py

@@ -0,0 +1,84 @@
import os

import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm


# Builds a dataset from the output of the image pair labeling UI below: each numbered
# subdirectory under `path` contains a pair of images ("1.jpg", "2.jpg") plus the pixel
# coordinates of one corresponding point in each image ("coords.pth").
class ImagePairWithCorrespondingPointsDataset(Dataset):
def __init__(self, opt):
self.opt = opt
self.path = opt['path']
        # Each entry produced by the labeling UI is a subdirectory; keep only those.
        self.pairs = [f for f in os.listdir(self.path) if os.path.isdir(os.path.join(self.path, f))]
self.transforms = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
self.size = opt['size']

    def __getitem__(self, item):
        pair_dir = self.pairs[item]
        img1 = self.transforms(Image.open(os.path.join(self.path, pair_dir, "1.jpg")))
        img2 = self.transforms(Image.open(os.path.join(self.path, pair_dir, "2.jpg")))
        coords1, coords2 = torch.load(os.path.join(self.path, pair_dir, "coords.pth"))
assert img1.shape[-2] == img1.shape[-1]
assert img2.shape[-2] == img2.shape[-1]
if img1.shape[-1] != self.size:
scale = img1.shape[-1] / self.size
            assert int(scale) == scale  # Only integer downsampling factors are supported.
scale = 1 / scale
img1 = torch.nn.functional.interpolate(img1.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
coords1 = [int(c * scale) for c in coords1]
if img2.shape[-1] != self.size:
scale = img2.shape[-1] / self.size
            assert int(scale) == scale  # Only integer downsampling factors are supported.
scale = 1 / scale
img2 = torch.nn.functional.interpolate(img2.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
coords2 = [int(c * scale) for c in coords2]
        coords1 = (coords1[1], coords1[0])  # The UI emits (x, y); flip to (y, x) to match tensor (row, col) indexing.
        coords2 = (coords2[1], coords2[0])
return {
'img1': img1,
'img2': img2,
'coords1': coords1,
'coords2': coords2
}

    def __len__(self):
        return len(self.pairs)


if __name__ == '__main__':
opt = {
'path': 'F:\\dlas\\codes\\scripts\\ui\\image_pair_labeler\\results',
'size': 256
}
output_path = '.'
ds = DataLoader(ImagePairWithCorrespondingPointsDataset(opt), shuffle=True, num_workers=0)
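    # Visual spot-check: black out a small square at each labeled point, then save both images.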
for i, d in tqdm(enumerate(ds)):
i1 = d['img1']
i2 = d['img2']
c1 = d['coords1']
c2 = d['coords2']
i1[:,:,c1[0]-3:c1[0]+3,c1[1]-3:c1[1]+3] = 0
i2[:,:,c2[0]-3:c2[0]+3,c2[1]-3:c2[1]+3] = 0
torchvision.utils.save_image(i1, f'{output_path}\\{i}_1.png')
torchvision.utils.save_image(i2, f'{output_path}\\{i}_2.png')
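
For reference, the on-disk layout this dataset reads is exactly what the labeling UI below writes: one numbered folder per labeled pair.

results/
    1/
        1.jpg         # first image of the pair (resized to 512x512 by the UI)
        2.jpg         # second image of the pair
        coords.pth    # [(x1, y1), (x2, y2)], written with torch.save
    2/
        ...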

View File

@@ -0,0 +1,104 @@
# Script that builds and launches a tkinter UI for labeling similar points between two images.
import os
import tkinter as tk
from glob import glob
from random import choices
import torch
from PIL import ImageTk, Image
# Globals used to define state that event handlers might operate on.
imgs_list = []
widgets = None
cur_img_1, cur_img_2 = None, None
pil_img_1, pil_img_2 = None, None
mode_select_image_1 = True
img_count = 1
img_loc_1 = None
output_location = "results"

def update_mode_label():
    global widgets, mode_select_image_1, img_count
    image_widget_1, image_widget_2, mode_label = widgets
    mode_str = "Select point in image 1" if mode_select_image_1 else "Select point in image 2"
    mode_label.config(text="%s; Next pair index: %i" % (mode_str, img_count))

# Handles key presses. <Tab> skips the current image pair and loads a fresh one.
def key_press(event):
    if event.char == '\t':
        next_images()
        update_mode_label()

def click(event):
    global img_loc_1, mode_select_image_1, pil_img_1, pil_img_2, img_count
x, y = event.x, event.y
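    # Tk reports click coordinates relative to the pane that received the click; each pane is 512x512.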
if x > 512 or y > 512:
print(f"Bounds error {x} {y}")
return
print(f"Detected click. {x} {y}")
if mode_select_image_1:
img_loc_1 = x, y
mode_select_image_1 = False
else:
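        # Note: makedirs() without exist_ok raises if the folder already exists, so re-running the tool cannot silently overwrite earlier labels.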
ofolder = f'{output_location}/{img_count}'
os.makedirs(ofolder)
pil_img_1.save(os.path.join(ofolder, "1.jpg"))
pil_img_2.save(os.path.join(ofolder, "2.jpg"))
torch.save([img_loc_1, (x,y)], os.path.join(ofolder, "coords.pth"))
img_count = img_count + 1
mode_select_image_1 = True
next_images()
update_mode_label()

def load_image_into_pane(img_path, pane, size=512):
pil_img = Image.open(img_path)
pil_img = pil_img.resize((size,size))
tk_picture = ImageTk.PhotoImage(pil_img)
pane.image = tk_picture
pane.configure(image=tk_picture)
return pil_img

def next_images():
global imgs_list, widgets, cur_img_1, cur_img_2, pil_img_1, pil_img_2
image_widget_1, image_widget_2, mode_label = widgets
cur_img_1, cur_img_2 = choices(imgs_list, k=2)
pil_img_1 = load_image_into_pane(cur_img_1, image_widget_1)
pil_img_2 = load_image_into_pane(cur_img_2, image_widget_2)


if __name__ == '__main__':
os.makedirs(output_location, exist_ok=True)
window = tk.Tk()
window.title("Image pair labeler UI")
window.geometry('1024x620+100+100')
# Load images
imgs_list = glob("E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new\\*.jpg")
# Photo view.
image_widget_1 = tk.Label(window)
image_widget_1.place(x=0, y=0, width=512, height=512)
image_widget_2 = tk.Label(window)
image_widget_2.place(x=512, y=0, width=512, height=512)
# Labels
mode_label = tk.Label(window, text="", anchor="w")
mode_label.place(x=20, y=590, width=400, height=20)
widgets = (image_widget_1, image_widget_2, mode_label)
window.bind("<Tab>", key_press) # Skip current patch
window.bind("<Button-1>", click)
next_images()
update_mode_label()
window.mainloop()
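
Assuming at least one pair has been labeled, a saved label can be reloaded for a quick sanity check; this minimal sketch just reads back what click() wrote for the hypothetical first pair:

import torch
from PIL import Image

img1 = Image.open("results/1/1.jpg")  # 512x512, as resized by load_image_into_pane()
img2 = Image.open("results/1/2.jpg")
(x1, y1), (x2, y2) = torch.load("results/1/coords.pth")
print(img1.size, (x1, y1), img2.size, (x2, y2))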

View File

@@ -0,0 +1,48 @@
import torch
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tqdm import tqdm

import trainer.eval.evaluator as evaluator
from data.image_pair_with_corresponding_points_dataset import ImagePairWithCorrespondingPointsDataset


# Uses two datasets, a "similar" and a "dissimilar" set, each of which contains pairs of images and the coordinates
# of similar/dissimilar points within those images. Uses the provided network to compute a latent vector at each
# labeled point and reports the mean L2 distance between the two latents of each pair. In a properly training
# network, similar points should draw closer together while dissimilar points stay constant or drift further apart.
class SinglePointPairContrastiveEval(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
self.batch_sz = opt_eval['batch_size']
self.eval_qty = opt_eval['quantity']
assert self.eval_qty % self.batch_sz == 0
        # The dataset takes its options as a single dict (see its __init__), so pass the args dict directly.
        self.similar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(opt_eval['similar_set_args']),
                                      shuffle=False, batch_size=self.batch_sz)
        self.dissimilar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(opt_eval['dissimilar_set_args']),
                                         shuffle=False, batch_size=self.batch_sz)

    def get_l2_score(self, dl):
        distances = []
        l2 = MSELoss()
        # Evaluation only; no gradients needed.
        with torch.no_grad():
            for i, data in tqdm(enumerate(dl)):
                latent1 = self.model(data['img1'], data['coords1'])
                latent2 = self.model(data['img2'], data['coords2'])
                distances.append(l2(latent1, latent2))
                if (i + 1) * self.batch_sz >= self.eval_qty:
                    break
        return torch.stack(distances).mean()

    def perform_eval(self):
        self.model.eval()
        print("Computing contrastive eval on similar set")
        similars = self.get_l2_score(self.similar_set)
        print("Computing contrastive eval on dissimilar set")
        dissimilars = self.get_l2_score(self.dissimilar_set)
        print(f"Eval done. val_similar_l2: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}")
return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item()}