Add dataset, UI for labeling, and evaluator for pointwise classification
parent fc623d4b5a
commit 23e01314d4
codes/data/image_pair_with_corresponding_points_dataset.py | 84 (new file)
@@ -0,0 +1,84 @@
import os

import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm


# Builds a dataset from a simple folder of image pairs, where each pair is stored in its own
# subdirectory along with the corresponding point labeled in both images.
class ImagePairWithCorrespondingPointsDataset(Dataset):
    def __init__(self, opt):
        self.opt = opt
        self.path = opt['path']
        # Each pair lives in its own subdirectory of `path`, containing "1.jpg", "2.jpg" and "coords.pth".
        self.pairs = [f for f in os.listdir(self.path) if os.path.isdir(os.path.join(self.path, f))]
        self.transforms = transforms.Compose([transforms.ToTensor(),
                                              transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                                              ])
        self.size = opt['size']

    def __getitem__(self, item):
        pair_dir = self.pairs[item]
        img1 = self.transforms(Image.open(os.path.join(self.path, pair_dir, "1.jpg")))
        img2 = self.transforms(Image.open(os.path.join(self.path, pair_dir, "2.jpg")))
        coords1, coords2 = torch.load(os.path.join(self.path, pair_dir, "coords.pth"))
        assert img1.shape[-2] == img1.shape[-1]
        assert img2.shape[-2] == img2.shape[-1]
        if img1.shape[-1] != self.size:
            scale = img1.shape[-1] / self.size
            assert int(scale) == scale  # We will only downsample by integer factors.
            scale = 1 / scale
            img1 = torch.nn.functional.interpolate(img1.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
            coords1 = [int(c * scale) for c in coords1]
        if img2.shape[-1] != self.size:
            scale = img2.shape[-1] / self.size
            assert int(scale) == scale  # We will only downsample by integer factors.
            scale = 1 / scale
            img2 = torch.nn.functional.interpolate(img2.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
            coords2 = [int(c * scale) for c in coords2]
        coords1 = (coords1[1], coords1[0])  # The UI puts these out as (x, y). Flip them to (y, x) for tensor indexing.
        coords2 = (coords2[1], coords2[0])
        return {
            'img1': img1,
            'img2': img2,
            'coords1': coords1,
            'coords2': coords2
        }

    def __len__(self):
        return len(self.pairs)

if __name__ == '__main__':
    opt = {
        'path': 'F:\\dlas\\codes\\scripts\\ui\\image_pair_labeler\\results',
        'size': 256
    }
    output_path = '.'

    ds = DataLoader(ImagePairWithCorrespondingPointsDataset(opt), shuffle=True, num_workers=0)
    for i, d in tqdm(enumerate(ds)):
        i1 = d['img1']
        i2 = d['img2']
        c1 = d['coords1']
        c2 = d['coords2']
        # Blank a small square around each labeled point so the pairing can be verified visually.
        i1[:, :, c1[0]-3:c1[0]+3, c1[1]-3:c1[1]+3] = 0
        i2[:, :, c2[0]-3:c2[0]+3, c2[1]-3:c2[1]+3] = 0
        torchvision.utils.save_image(i1, os.path.join(output_path, f'{i}_1.png'))
        torchvision.utils.save_image(i2, os.path.join(output_path, f'{i}_2.png'))
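
A minimal round-trip sketch of the directory contract this dataset expects: one subdirectory per pair holding "1.jpg", "2.jpg" and a coords.pth with two (x, y) tuples, as produced by the labeler UI below. This sketch is not part of the commit; the temp path and sizes are illustrative.

# Fabricate one pair directory, then load it back through the dataset.
import os
import torch
from PIL import Image

root = '/tmp/pair_ds_demo'  # illustrative path
os.makedirs(os.path.join(root, '1'), exist_ok=True)
Image.new('RGB', (512, 512)).save(os.path.join(root, '1', '1.jpg'))
Image.new('RGB', (512, 512)).save(os.path.join(root, '1', '2.jpg'))
torch.save([(100, 200), (300, 400)], os.path.join(root, '1', 'coords.pth'))  # (x, y) in 512px space

ds = ImagePairWithCorrespondingPointsDataset({'path': root, 'size': 256})
item = ds[0]
print(item['img1'].shape)  # torch.Size([3, 256, 256])
print(item['coords1'])     # (100, 50): halved by the 2x downsample, then flipped to (y, x)
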
codes/scripts/ui/image_pair_labeler/image_pair_ui.py | 104 (new file)
@@ -0,0 +1,104 @@
# Script that builds and launches a tkinter UI for labeling similar points between two images.
import os
import tkinter as tk
from glob import glob
from random import sample

import torch
from PIL import ImageTk, Image

# Globals used to define state that event handlers might operate on.
imgs_list = []
widgets = None
cur_img_1, cur_img_2 = None, None
pil_img_1, pil_img_2 = None, None
pending_labels = []
mode_select_image_1 = True
img_count = 1
img_loc_1 = None
output_location = "results"


def update_mode_label():
    global widgets, mode_select_image_1, img_count
    image_widget_1, image_widget_2, mode_label = widgets
    mode_str = "Select point in image 1" if mode_select_image_1 else "Select point in image 2"
    mode_label.config(text="%s; Saved images: %i" % (mode_str, img_count))


# Handles key presses. Tab skips the current image pair and advances to the next one.
def key_press(event):
    if event.char == '\t':
        next_images()

    update_mode_label()


def click(event):
    global img_loc_1, mode_select_image_1, pil_img_1, pil_img_2, img_count
    x, y = event.x, event.y
    if x > 512 or y > 512:
        print(f"Bounds error {x} {y}")
        return

    print(f"Detected click. {x} {y}")
    if mode_select_image_1:
        img_loc_1 = x, y
        mode_select_image_1 = False
    else:
        # Second click: persist the image pair and both (x, y) coordinates, then advance.
        ofolder = f'{output_location}/{img_count}'
        os.makedirs(ofolder)
        pil_img_1.save(os.path.join(ofolder, "1.jpg"))
        pil_img_2.save(os.path.join(ofolder, "2.jpg"))
        torch.save([img_loc_1, (x, y)], os.path.join(ofolder, "coords.pth"))
        img_count = img_count + 1
        mode_select_image_1 = True
        next_images()
    update_mode_label()


def load_image_into_pane(img_path, pane, size=512):
    pil_img = Image.open(img_path)
    pil_img = pil_img.resize((size, size))
    tk_picture = ImageTk.PhotoImage(pil_img)
    pane.image = tk_picture
    pane.configure(image=tk_picture)
    return pil_img


def next_images():
    global imgs_list, widgets, cur_img_1, cur_img_2, pil_img_1, pil_img_2
    image_widget_1, image_widget_2, mode_label = widgets

    cur_img_1, cur_img_2 = sample(imgs_list, k=2)  # sample() guarantees two distinct images.
    pil_img_1 = load_image_into_pane(cur_img_1, image_widget_1)
    pil_img_2 = load_image_into_pane(cur_img_2, image_widget_2)


if __name__ == '__main__':
    os.makedirs(output_location, exist_ok=True)

    window = tk.Tk()
    window.title("Image pair labeler UI")
    window.geometry('1024x620+100+100')

    # Load images
    imgs_list = glob("E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new\\*.jpg")

    # Photo view.
    image_widget_1 = tk.Label(window)
    image_widget_1.place(x=0, y=0, width=512, height=512)
    image_widget_2 = tk.Label(window)
    image_widget_2.place(x=512, y=0, width=512, height=512)

    # Labels
    mode_label = tk.Label(window, text="", anchor="w")
    mode_label.place(x=20, y=590, width=400, height=20)

    widgets = (image_widget_1, image_widget_2, mode_label)

    window.bind("<Tab>", key_press)  # Skip the current image pair.
    window.bind("<Button-1>", click)
    next_images()
    update_mode_label()
    window.mainloop()
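
To double-check a saved label, a small verification sketch that re-draws the stored points onto the saved pair. This is not part of the commit; the folder "results/1" and the marker size are illustrative.

# Re-open a saved pair and mark each stored (x, y) point.
import torch
from PIL import Image, ImageDraw

(x1, y1), (x2, y2) = torch.load('results/1/coords.pth')
for name, (x, y) in (('1.jpg', (x1, y1)), ('2.jpg', (x2, y2))):
    img = Image.open(f'results/1/{name}')
    ImageDraw.Draw(img).ellipse((x - 4, y - 4, x + 4, y + 4), outline='red', width=2)
    img.show()
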
codes/trainer/eval/single_point_pair_contrastive_eval.py | 48 (new file)
@@ -0,0 +1,48 @@
import torch
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tqdm import tqdm

import trainer.eval.evaluator as evaluator

from data.image_pair_with_corresponding_points_dataset import ImagePairWithCorrespondingPointsDataset


# Uses two datasets, a "similar" set and a "dissimilar" set, each of which contains pairs of images and
# corresponding points within those images. Uses the provided network to compute a latent vector at each
# point and reports the mean l2 distance between the two latents of each pair, for both sets. A properly
# trained network will show similar points getting closer while dissimilar points remain constant or get
# further apart.
class SinglePointPairContrastiveEval(evaluator.Evaluator):
    def __init__(self, model, opt_eval, env):
        super().__init__(model, opt_eval, env)
        self.batch_sz = opt_eval['batch_size']
        self.eval_qty = opt_eval['quantity']
        assert self.eval_qty % self.batch_sz == 0
        # The dataset takes a single opt dict with 'path' and 'size' keys.
        self.similar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(opt_eval['similar_set_args']), shuffle=False, batch_size=self.batch_sz)
        self.dissimilar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(opt_eval['dissimilar_set_args']), shuffle=False, batch_size=self.batch_sz)

    def get_l2_score(self, dl):
        distances = []
        l2 = MSELoss()
        with torch.no_grad():  # No gradients are needed during evaluation.
            for i, data in tqdm(enumerate(dl)):
                latent1 = self.model(data['img1'], data['coords1'])
                latent2 = self.model(data['img2'], data['coords2'])
                distances.append(l2(latent1, latent2))
                if (i + 1) * self.batch_sz >= self.eval_qty:  # Stop once eval_qty samples have been consumed.
                    break

        return torch.stack(distances).mean()

    def perform_eval(self):
        self.model.eval()
        print("Computing contrastive eval on similar set")
        similars = self.get_l2_score(self.similar_set)
        print("Computing contrastive eval on dissimilar set")
        dissimilars = self.get_l2_score(self.dissimilar_set)
        print(f"Eval done. val_similar_l2: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}")
        return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item()}
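
For reference, a hypothetical wiring sketch for this evaluator. The stand-in model, the empty env dict, and all paths and numbers are assumptions, not part of the commit; a real model must accept (images, coords) and return one latent per image.

# Stand-in model: returns a fake 64-dim latent per image.
import torch

class DummyPointLatentModel(torch.nn.Module):
    def forward(self, imgs, coords):
        return torch.randn(imgs.shape[0], 64)

opt_eval = {
    'batch_size': 4,
    'quantity': 16,  # must be a multiple of batch_size
    'similar_set_args': {'path': '<similar_pairs_dir>', 'size': 256},
    'dissimilar_set_args': {'path': '<dissimilar_pairs_dir>', 'size': 256},
}
ev = SinglePointPairContrastiveEval(DummyPointLatentModel(), opt_eval, env={})
metrics = ev.perform_eval()  # {'val_similar_l2': ..., 'val_dissimilar_l2': ...}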