# Forked from mrq/DL-Art-School.
# File listing metadata: 105 lines, 3.1 KiB, Python.
# Script that builds and launches a tkinter UI for labeling similar points between two images.
|
|
import os
|
|
import tkinter as tk
|
|
from glob import glob
|
|
from random import choices
|
|
|
|
import torch
|
|
from PIL import ImageTk, Image
|
|
|
|
# Module-level state shared by the tkinter event handlers in this file.

# Root folder for results; each saved pair goes into a numbered subfolder of it.
output_location = "results"

# Pool of image paths and the UI handles; both are filled in by the __main__ block.
imgs_list = list()
widgets = None  # Becomes (image_widget_1, image_widget_2, mode_label).

# The pair currently on screen: source paths and the resized PIL images.
cur_img_1 = None
cur_img_2 = None
pil_img_1 = None
pil_img_2 = None

# Labeling progress.
mode_select_image_1 = True  # True: the next click picks a point in image 1; False: image 2.
img_loc_1 = None  # (x, y) picked in image 1 while waiting for the click in image 2.
img_count = 1  # Number used to name the next output subfolder.
pending_labels = list()  # Not read or written by any handler visible in this file.
|
|
|
|
|
|
def update_mode_label():
    """Refresh the status label with the active selection mode and the save count.

    Reads the module-level ``widgets``, ``mode_select_image_1`` and ``img_count``;
    the only thing modified is the text of the status label.
    """
    # Only the status label is needed from the (image_1, image_2, label) tuple.
    mode_label = widgets[2]
    target = 1 if mode_select_image_1 else 2
    # img_count starts at 1 and names the *next* folder to be written, so the
    # number of pairs saved so far is one less than its value.
    saved = img_count - 1
    mode_label.config(text=f"Select point in image {target}; Saved images: {saved}")
|
|
|
|
|
|
# Handles key presses. Tab skips the image pair that is currently displayed.
def key_press(event):
    """Skip the current image pair when Tab is pressed, then refresh the status label.

    Args:
        event: tkinter key event; only ``event.char`` is inspected.
    """
    global mode_select_image_1, img_loc_1

    if event.char == '\t':
        # Discard any point already picked in the pair being skipped. Without
        # this reset the next click would be treated as the image-2 point and
        # saved together with a stale image-1 location from the old pair.
        mode_select_image_1 = True
        img_loc_1 = None
        next_images()

    update_mode_label()
|
|
|
|
|
|
def click(event):
    """Record a clicked point; once both points are known, save the labeled pair.

    The first accepted click stores a point for image 1. The second accepted
    click (in image 2) writes ``1.jpg``, ``2.jpg`` and ``coords.pth`` (a list of
    the two ``(x, y)`` tuples) into ``<output_location>/<img_count>/``, advances
    ``img_count`` and loads a new pair. Rejected clicks change no state.

    Args:
        event: tkinter mouse event. ``event.x`` / ``event.y`` are taken as pixel
            coordinates and ``event.widget`` identifies what was clicked.
    """
    global img_loc_1, mode_select_image_1, img_count

    # The image panes are placed at 512x512, so valid coordinates are 0..511.
    pane_size = 512
    x, y = event.x, event.y
    if not (0 <= x < pane_size and 0 <= y < pane_size):
        print(f"Bounds error {x} {y}")
        return

    # The handler is bound on the window, so it also fires for clicks on the
    # status label and on the pane that is not being labeled right now. Their
    # coordinates are relative to the clicked widget and would otherwise be
    # stored as if they were points in the expected image.
    expected_pane = widgets[0] if mode_select_image_1 else widgets[1]
    if event.widget is not expected_pane:
        print(f"Ignoring click outside the expected image. {x} {y}")
        return

    print(f"Detected click. {x} {y}")
    if mode_select_image_1:
        img_loc_1 = x, y
        mode_select_image_1 = False
    else:
        # Skip numbers whose folder already exists (e.g. left over from an
        # earlier session) so old results are neither overwritten nor cause
        # os.makedirs to raise and leave the handler stuck in this mode.
        ofolder = f'{output_location}/{img_count}'
        while os.path.exists(ofolder):
            img_count = img_count + 1
            ofolder = f'{output_location}/{img_count}'
        os.makedirs(ofolder)
        pil_img_1.save(os.path.join(ofolder, "1.jpg"))
        pil_img_2.save(os.path.join(ofolder, "2.jpg"))
        torch.save([img_loc_1, (x, y)], os.path.join(ofolder, "coords.pth"))
        img_count = img_count + 1
        mode_select_image_1 = True
        next_images()
    update_mode_label()
|
|
|
|
|
|
def load_image_into_pane(img_path, pane, size=512):
    """Load an image from disk, resize it to a square and display it in a widget.

    Args:
        img_path: Path of the image file to open.
        pane: Widget to show the image in; it must accept ``configure(image=...)``.
        size: Edge length in pixels of the displayed and returned image.

    Returns:
        The resized PIL image. The aspect ratio of the source is not preserved.
    """
    # Image.open keeps the underlying file open; the context manager closes it
    # as soon as the resized copy (an independent image) has been created.
    with Image.open(img_path) as source:
        pil_img = source.resize((size, size))
    tk_picture = ImageTk.PhotoImage(pil_img)
    # Store the PhotoImage on the widget so a Python reference to it outlives
    # this function call.
    pane.image = tk_picture
    pane.configure(image=tk_picture)
    return pil_img
|
|
|
|
def next_images():
    """Pick two random paths from ``imgs_list`` and show them in the two image panes.

    Updates the module-level ``cur_img_1`` / ``cur_img_2`` paths and the
    ``pil_img_1`` / ``pil_img_2`` images. The draw is made with replacement, so
    the same path can be picked for both panes.
    """
    global cur_img_1, cur_img_2, pil_img_1, pil_img_2

    left_pane, right_pane, _ = widgets

    picked = choices(imgs_list, k=2)
    cur_img_1 = picked[0]
    cur_img_2 = picked[1]

    pil_img_1 = load_image_into_pane(cur_img_1, left_pane)
    pil_img_2 = load_image_into_pane(cur_img_2, right_pane)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: collect the image paths, build the window, then hand
    # control to the tkinter main loop.
    import sys

    os.makedirs(output_location, exist_ok=True)

    # Load images. An optional first command-line argument overrides the
    # default glob pattern, so the tool can be used with other datasets.
    default_pattern = "E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new\\*.jpg"
    pattern = sys.argv[1] if len(sys.argv) > 1 else default_pattern
    imgs_list = glob(pattern)
    if not imgs_list:
        # Stop before any window is created; next_images() would otherwise
        # fail inside random.choices with an error that does not name the path.
        raise SystemExit(f"No images found matching: {pattern}")

    window = tk.Tk()
    window.title("Image pair labeler UI")
    window.geometry('1024x620+100+100')

    # Photo view: two 512x512 panes side by side.
    image_widget_1 = tk.Label(window)
    image_widget_1.place(x=0, y=0, width=512, height=512)
    image_widget_2 = tk.Label(window)
    image_widget_2.place(x=512, y=0, width=512, height=512)

    # Labels
    mode_label = tk.Label(window, text="", anchor="w")
    mode_label.place(x=20, y=590, width=400, height=20)

    widgets = (image_widget_1, image_widget_2, mode_label)

    window.bind("<Tab>", key_press)  # Skip the current image pair
    window.bind("<Button-1>", click)
    next_images()
    update_mode_label()
    window.mainloop()
|