forked from mrq/DL-Art-School
Add dataset, ui for labeling and evaluator for pointwise classification
This commit is contained in:
parent
fc623d4b5a
commit
23e01314d4
84
codes/data/image_pair_with_corresponding_points_dataset.py
Normal file
84
codes/data/image_pair_with_corresponding_points_dataset.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
import glob
|
||||
import itertools
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import kornia
|
||||
import numpy as np
|
||||
import pytorch_ssim
|
||||
import torch
|
||||
import os
|
||||
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import Normalize
|
||||
from tqdm import tqdm
|
||||
|
||||
from data import util
|
||||
# Builds a dataset from a folder of labeled samples; each sample sub-folder holds two images (1.jpg, 2.jpg) and the coordinates of one corresponding point in each (coords.pth).
|
||||
from data.image_corruptor import ImageCorruptor
|
||||
from data.image_label_parser import VsNetImageLabeler
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class ImagePairWithCorrespondingPointsDataset(Dataset):
    """Dataset of image pairs, each labeled with one corresponding point per image.

    Every sample is a sub-directory of ``opt['path']`` containing ``1.jpg``,
    ``2.jpg`` and ``coords.pth``. ``coords.pth`` stores two ``(x, y)`` pixel
    coordinates (one per image) in the resolution of the stored images.

    Args:
        opt: Mapping with the keys ``'path'`` (root directory holding the sample
            sub-directories) and ``'size'`` (edge length in pixels of the square
            images that ``__getitem__`` returns).
    """

    def __init__(self, opt):
        self.opt = opt
        self.path = opt['path']
        self.pairs = self._list_pair_dirs(self.path)
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.size = opt['size']

    @staticmethod
    def _list_pair_dirs(root):
        """Return the sorted names of the sample sub-directories directly under *root*."""
        # Entries must be resolved against `root` (not the CWD), and only
        # directories are samples; stray files would fail later in __getitem__.
        # Sorting makes the index -> sample mapping reproducible across runs.
        return sorted(name for name in os.listdir(root)
                      if os.path.isdir(os.path.join(root, name)))

    @staticmethod
    def _fit_to_size(img, coords, size):
        """Downsample a square ``(C, H, W)`` tensor to ``size`` and rescale ``coords`` to match.

        Returns:
            ``(img, coords)``. Both are returned untouched when the image already
            has the requested size.

        Raises:
            ValueError: If the image is not square, or its edge length is not a
                positive integer multiple of ``size``.
        """
        height, width = img.shape[-2], img.shape[-1]
        if height != width:
            raise ValueError(f"Expected a square image, got {height}x{width}")
        if width == size:
            return img, coords
        factor, remainder = divmod(width, size)
        if factor < 1 or remainder:
            # We will only downsample by whole-number factors.
            raise ValueError(f"Image size {width} is not an integer multiple of target size {size}")
        # Pass the exact output size instead of a float scale factor so the
        # result cannot end up a pixel short because of float rounding.
        img = torch.nn.functional.interpolate(img.unsqueeze(0), size=(size, size), mode='bilinear',
                                              align_corners=False).squeeze(0)
        return img, [int(c / factor) for c in coords]

    def _load_image(self, img_path):
        """Read an image file and return it as a normalized 3-channel tensor."""
        # The context manager closes the file handle; convert('RGB') guarantees
        # the three channels that the Normalize transform expects.
        with Image.open(img_path) as img:
            return self.transforms(img.convert('RGB'))

    def __getitem__(self, item):
        """Load sample ``item``.

        Returns:
            dict with ``'img1'``/``'img2'`` (normalized tensors resized to
            ``size`` x ``size``) and ``'coords1'``/``'coords2'`` (the labeled
            point of each image as ``(y, x)``, in the resized image's pixels).
        """
        pair_dir = os.path.join(self.path, self.pairs[item])
        img1 = self._load_image(os.path.join(pair_dir, "1.jpg"))
        img2 = self._load_image(os.path.join(pair_dir, "2.jpg"))
        # NOTE: torch.load unpickles the file; only point this at trusted label folders.
        coords1, coords2 = torch.load(os.path.join(pair_dir, "coords.pth"))

        img1, coords1 = self._fit_to_size(img1, coords1, self.size)
        img2, coords2 = self._fit_to_size(img2, coords2, self.size)

        # The UI puts these out backwards (x,y). Flip them.
        coords1 = (coords1[1], coords1[0])
        coords2 = (coords2[1], coords2[0])
        return {
            'img1': img1,
            'img2': img2,
            'coords1': coords1,
            'coords2': coords2
        }

    def __len__(self):
        """Number of labeled pairs found under the dataset root."""
        return len(self.pairs)
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: load every labeled pair, black out a small square
    # around each labeled point and write both images out for visual inspection.
    opt = {
        'path': 'F:\\dlas\\codes\\scripts\\ui\\image_pair_labeler\\results',
        'size': 256
    }
    output_path = '.'

    ds = DataLoader(ImagePairWithCorrespondingPointsDataset(opt), shuffle=True, num_workers=0)
    # Wrapping the loader (rather than the enumerate object) lets tqdm show a total.
    for i, d in enumerate(tqdm(ds)):
        i1 = d['img1']
        i2 = d['img2']
        c1 = d['coords1']
        c2 = d['coords2']
        for img, coords in ((i1, c1), (i2, c2)):
            row, col = int(coords[0]), int(coords[1])
            # Clamp the start at 0: a negative slice start would wrap around and
            # the marker would silently disappear for points near the border.
            img[:, :, max(row - 3, 0):row + 3, max(col - 3, 0):col + 3] = 0
        torchvision.utils.save_image(i1, os.path.join(output_path, f'{i}_1.png'))
        torchvision.utils.save_image(i2, os.path.join(output_path, f'{i}_2.png'))
|
104
codes/scripts/ui/image_pair_labeler/image_pair_ui.py
Normal file
104
codes/scripts/ui/image_pair_labeler/image_pair_ui.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
# Script that builds and launches a tkinter UI for labeling similar points between two images.
|
||||
import os
|
||||
import tkinter as tk
|
||||
from glob import glob
|
||||
from random import choices
|
||||
|
||||
import torch
|
||||
from PIL import ImageTk, Image
|
||||
|
||||
# Globals used to define state that event handlers might operate on.
|
||||
imgs_list = []
|
||||
widgets = None
|
||||
cur_img_1, cur_img_2 = None, None
|
||||
pil_img_1, pil_img_2 = None, None
|
||||
pending_labels = []
|
||||
mode_select_image_1 = True
|
||||
img_count = 1
|
||||
img_loc_1 = None
|
||||
output_location = "results"
|
||||
|
||||
|
||||
def update_mode_label():
    """Refresh the status label with the pane to click next and the saved-pair count."""
    _, _, mode_label = widgets
    mode_str = "Select point in image 1" if mode_select_image_1 else "Select point in image 2"
    # img_count is the folder index the NEXT saved pair will use (it starts at
    # 1), so the number of pairs already written is one less.
    mode_label.config(text="%s; Saved images: %i" % (mode_str, img_count - 1))
|
||||
|
||||
|
||||
# Handles key presses. Tab skips the currently displayed image pair and loads a new one.
|
||||
def key_press(event):
    """Handle a key press: Tab discards the current image pair and loads a new one.

    Args:
        event: tkinter key event; only ``event.char`` is inspected.
    """
    global mode_select_image_1, img_loc_1

    if event.char == '\t':
        # A point already picked in the discarded pair must not be combined
        # with a click in the new pair, so restart the selection from image 1.
        img_loc_1 = None
        mode_select_image_1 = True
        next_images()

    update_mode_label()
|
||||
|
||||
|
||||
def click(event):
    """Handle a left click: record a point in image 1, then a point in image 2.

    The first accepted click stores the point for image 1. The second accepted
    click writes both displayed images plus ``coords.pth`` (a list of the two
    ``(x, y)`` points) into a new numbered folder under ``output_location`` and
    loads the next image pair.

    Args:
        event: tkinter mouse event; ``event.x``/``event.y`` are relative to
            ``event.widget``.
    """
    global img_loc_1, mode_select_image_1, img_count

    x, y = event.x, event.y
    # The panes are 512 px wide/high, so valid pixel indices are 0..511.
    if not (0 <= x < 512 and 0 <= y < 512):
        print(f"Bounds error {x} {y}")
        return

    # Coordinates are widget-relative, so a click on the other pane or on the
    # status label would otherwise be stored as a point in the wrong image.
    expected_pane = widgets[0] if mode_select_image_1 else widgets[1]
    if event.widget is not expected_pane:
        print(f"Ignoring click outside the expected image pane. {x} {y}")
        return

    print(f"Detected click. {x} {y}")
    if mode_select_image_1:
        img_loc_1 = x, y
        mode_select_image_1 = False
    else:
        ofolder = f'{output_location}/{img_count}'
        # img_count restarts at 1 every session; skip folders left by earlier
        # sessions instead of crashing on (or overwriting) existing results.
        while os.path.exists(ofolder):
            img_count = img_count + 1
            ofolder = f'{output_location}/{img_count}'
        os.makedirs(ofolder)
        pil_img_1.save(os.path.join(ofolder, "1.jpg"))
        pil_img_2.save(os.path.join(ofolder, "2.jpg"))
        torch.save([img_loc_1, (x, y)], os.path.join(ofolder, "coords.pth"))
        img_count = img_count + 1
        mode_select_image_1 = True
        next_images()
    update_mode_label()
|
||||
|
||||
|
||||
def load_image_into_pane(img_path, pane, size=512):
    """Load an image, resize it to ``size`` x ``size`` and display it in ``pane``.

    Args:
        img_path: Path of the image file to show.
        pane: tkinter widget that accepts an ``image`` option.
        size: Edge length in pixels of the displayed (and returned) image.

    Returns:
        The resized PIL image that is being displayed.
    """
    # resize() reads the pixel data and returns an independent image, so the
    # source file handle can be closed right away instead of being leaked.
    with Image.open(img_path) as source:
        pil_img = source.resize((size, size))
    tk_picture = ImageTk.PhotoImage(pil_img)
    # Keep a reference on the widget; otherwise the PhotoImage is garbage
    # collected and the pane goes blank.
    pane.image = tk_picture
    pane.configure(image=tk_picture)
    return pil_img
|
||||
|
||||
def next_images():
    """Draw two random entries from ``imgs_list`` and show them in the two image panes.

    Updates the module-level ``cur_img_1``/``cur_img_2`` (chosen paths) and
    ``pil_img_1``/``pil_img_2`` (the images as displayed).
    """
    global cur_img_1, cur_img_2, pil_img_1, pil_img_2

    left_pane, right_pane, _ = widgets

    # NOTE(review): choices() samples with replacement, so both panes can end
    # up showing the same file - confirm whether that is intended.
    picked = choices(imgs_list, k=2)
    cur_img_1 = picked[0]
    cur_img_2 = picked[1]

    pil_img_1 = load_image_into_pane(cur_img_1, left_pane)
    pil_img_2 = load_image_into_pane(cur_img_2, right_pane)
|
||||
|
||||
if __name__ == '__main__':
    os.makedirs(output_location, exist_ok=True)

    # Load images
    imgs_list = glob("E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new\\*.jpg")
    if not imgs_list:
        # Without this check the first next_images() call fails with an opaque
        # IndexError raised from random.choices().
        raise SystemExit("No images found to label; check the glob pattern for imgs_list.")

    window = tk.Tk()
    window.title("Image pair labeler UI")
    window.geometry('1024x620+100+100')

    # Photo view: two 512x512 panes side by side.
    image_widget_1 = tk.Label(window)
    image_widget_1.place(x=0, y=0, width=512, height=512)
    image_widget_2 = tk.Label(window)
    image_widget_2.place(x=512, y=0, width=512, height=512)

    # Labels
    mode_label = tk.Label(window, text="", anchor="w")
    mode_label.place(x=20, y=590, width=400, height=20)

    widgets = (image_widget_1, image_widget_2, mode_label)

    window.bind("<Tab>", key_press)  # Skip the current image pair
    window.bind("<Button-1>", click)
    next_images()
    update_mode_label()
    window.mainloop()
|
48
codes/trainer/eval/single_point_pair_contrastive_eval.py
Normal file
48
codes/trainer/eval/single_point_pair_contrastive_eval.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import os.path as osp
|
||||
import torchvision
|
||||
from torch.nn import MSELoss
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
import trainer.eval.evaluator as evaluator
|
||||
from pytorch_fid import fid_score
|
||||
|
||||
from data.image_pair_with_corresponding_points_dataset import ImagePairWithCorrespondingPointsDataset
|
||||
from utils.util import opt_get
|
||||
|
||||
# Uses two datasets: a "similar" and "dissimilar" dataset, each of which contains pairs of images and similar/dissimilar
|
||||
# points in those datasets. Uses the provided network to compute a latent vector for both similar and dissimilar.
|
||||
# Reports a score for the l2 distance of both. A properly trained network will show similar points getting closer while
|
||||
# dissimilar points remain constant or get further apart.
|
||||
class SinglePointPairContrastiveEval(evaluator.Evaluator):
    """Evaluator reporting the latent distance between labeled point pairs.

    Two datasets of image pairs are used, one labeled with "similar" points and
    one with "dissimilar" points. For every pair the model is called as
    ``model(img, coords)`` on both images and the MSE between the two outputs is
    averaged per dataset.

    Options read from ``opt_eval``:
        batch_size: Batch size of both data loaders.
        quantity: Number of pairs to evaluate per dataset; must be a multiple of
            ``batch_size``.
        similar_set_args / dissimilar_set_args: Options for
            ``ImagePairWithCorrespondingPointsDataset``.
    """

    def __init__(self, model, opt_eval, env):
        super().__init__(model, opt_eval, env)
        self.batch_sz = opt_eval['batch_size']
        self.eval_qty = opt_eval['quantity']
        if self.eval_qty % self.batch_sz != 0:
            raise ValueError(f"quantity ({self.eval_qty}) must be a multiple of batch_size ({self.batch_sz})")
        self.similar_set = self._build_loader(opt_eval['similar_set_args'])
        self.dissimilar_set = self._build_loader(opt_eval['dissimilar_set_args'])

    def _build_loader(self, dataset_args):
        """Create a non-shuffling DataLoader over the point-pair dataset described by ``dataset_args``."""
        if 'path' in dataset_args:
            # The dataset constructor takes the options mapping as ONE positional
            # argument; unpacking a flat {'path': ..., 'size': ...} mapping with
            # ** raises a TypeError.
            dataset = ImagePairWithCorrespondingPointsDataset(dataset_args)
        else:
            # Keep supporting args that are already keyed by parameter name.
            dataset = ImagePairWithCorrespondingPointsDataset(**dataset_args)
        return DataLoader(dataset, shuffle=False, batch_size=self.batch_sz)

    def get_l2_score(self, dl):
        """Return the mean MSE between the model outputs of both images over ``dl``.

        At most ``quantity`` pairs (``quantity / batch_size`` batches) are used.

        Raises:
            RuntimeError: If no batch was evaluated.
        """
        distances = []
        l2 = MSELoss()
        max_batches = self.eval_qty // self.batch_sz
        # NOTE(review): batches are passed to the model on whatever device the
        # loader yields them - confirm the model lives on the same device.
        # no_grad: pure evaluation; without it every stored distance keeps its
        # whole autograd graph alive.
        with torch.no_grad():
            for i, data in tqdm(enumerate(dl), total=min(max_batches, len(dl))):
                # Check BEFORE evaluating so exactly max_batches batches are used
                # (checking afterwards evaluates one batch too many).
                if i >= max_batches:
                    break
                latent1 = self.model(data['img1'], data['coords1'])
                latent2 = self.model(data['img2'], data['coords2'])
                distances.append(l2(latent1, latent2))

        if not distances:
            raise RuntimeError("No batches were evaluated; check 'quantity' and the dataset contents.")
        return torch.stack(distances).mean()

    def perform_eval(self):
        """Run the evaluation on both datasets.

        Returns:
            dict with the keys ``'val_similar_l2'`` and ``'val_dissimilar_l2'``.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            print("Computing contrastive eval on similar set")
            similars = self.get_l2_score(self.similar_set).item()
            print("Computing contrastive eval on dissimilar set")
            dissimilars = self.get_l2_score(self.dissimilar_set).item()
        finally:
            # Hand the model back in the mode it was given to us in.
            self.model.train(was_training)
        print(f"Eval done. val_similar_l2: {similars}; val_dissimilar_l2: {dissimilars}")
        return {"val_similar_l2": similars, "val_dissimilar_l2": dissimilars}
|
Loading…
Reference in New Issue
Block a user