# Script that builds and launches a tkinter UI for directly interacting with a pytorch model that performs
# patch-level image classification.
#
# This script is a bit rough around the edges. I threw it together quickly to ascertain its usefulness. If I end up
# using it a lot, I may re-visit it.

import os
import tkinter as tk

import torch
import torchvision
from PIL import ImageTk

from data.image_label_parser import VsNetImageLabeler
from scripts.ui.image_labeler.pretrained_image_patch_classifier import PretrainedImagePatchClassifier

# Globals used to define state that event handlers might operate on.
classifier = None
gen = None  # Iterator from classifier.get_next_sample(); next_batch() unpacks each item as (hq, res, data).
labeler = None
to_pil = torchvision.transforms.ToPILImage()
widgets = None  # (image_widget, primary_label, secondary_label, mode_label), assigned in __main__.
batch_gen = None  # Active next_batch() generator; each next() call displays one patch.
cur_img = 0  # Index of the image within the current batch that is being scanned.
batch_sz = 0
mode = 0  # Index into labeler.str_labels of the classification label currently being targeted.
cur_path, cur_top, cur_left, cur_dim = None, None, None, None  # Displayed patch; None when no patch is active.
pending_labels = []  # Labels recorded by key_press() that have not yet been written by save().


def update_mode_label():
    """Refresh the status line with the active mode and the number of labels on the pending stack.

    Note: the "Saved images" count is len(pending_labels), i.e. labels recorded but not yet written by save().
    """
    image_widget, primary_label, secondary_label, mode_label = widgets
    mode_label.config(text="Current mode: %s; Saved images: %i" % (labeler.str_labels[mode], len(pending_labels)))


def change_mode(event):
    """Handle the "change mode" hotkey by cycling to the next classification label being targeted.

    The mode wraps back to 0 after the last label, so it always indexes a valid entry of labeler.str_labels.
    Assumes labeler.str_labels has one entry per classifier output channel -- TODO confirm.
    """
    global mode
    mode = (mode + 1) % len(labeler.str_labels)
    update_mode_label()


def key_press(event):
    """Handle a key press: categorize the active image patch, then advance to the next patch.

    A Tab key press skips the active patch without recording anything. Any other key must be one of
    labeler.categories; it is recorded on the pending stack together with the active patch's coordinates.
    """
    if event.char != '\t':
        if event.char not in labeler.categories:
            print("Specified category doesn't exist.")
            return
        if cur_path is None or cur_top is None:
            # No patch is on screen (nothing loaded yet, or the batch ran out); don't record a bogus label.
            print("No active patch to label.")
            return
        cat = labeler.categories[event.char]
        img_name = os.path.basename(cur_path)
        print(cat['label'], img_name, cur_top, cur_left, cur_dim)
        pending_labels.append([event.char, img_name, cur_top, cur_left, cur_dim])
    # next_batch() signals exhaustion by returning, which next() surfaces as StopIteration; the default
    # value turns that into False so we fall through to loading the next image/batch.
    if batch_gen is None or not next(batch_gen, False):
        next_image(None)
    update_mode_label()
def undo(event):
    """Pop the most recently recorded label off of the pending stack (does nothing if the stack is empty)."""
    if not pending_labels:
        print("No pending labels to undo.")
        return
    c, nm, t, l, d = pending_labels.pop()
    print("Removed pending label", c, nm, t, l, d)
    update_mode_label()


def save(event):
    """Write every pending label to the underlying label file, then clear the pending stack."""
    global pending_labels
    print("Saving %i labels" % len(pending_labels))
    for lbl in pending_labels:
        labeler.add_label(*lbl)
    labeler.save()
    pending_labels = []
    update_mode_label()


def next_batch(conf_lower=.8, conf_upper=1):
    """Main controller for the state machine that this UI attaches to.

    Performs inference on one batch of images drawn from `gen`. It then repeatedly displays image patches
    whose score for the active `mode` lies strictly between `conf_lower` and `conf_upper` (bounds applied to
    the post-softmax output), yielding True after each patch is shown. The generator returns once every
    image in the batch has been scanned.

    `cur_img` may be advanced from outside (see next_image()). When that happens, the scan of the current
    image is abandoned and restarts from the first patch of the new image.
    """
    global cur_img, batch_sz, cur_top, cur_left, cur_dim, cur_path
    image_widget, primary_label, secondary_label = widgets[:3]
    hq, res, data = next(gen)
    # Assumes hq is (batch, channels, H, W) and res is (batch, classes, h, w), with H/W integer multiples
    # of h/w -- TODO confirm against PretrainedImagePatchClassifier.
    scale = hq.shape[-1] // res.shape[-1]
    rows, cols = res.shape[-2], res.shape[-1]
    # 1.0 where a score falls strictly inside the confidence bounds, 0.0 everywhere else.
    valid_res = ((res > conf_lower) * (res < conf_upper)) * 1.0

    cur_img = 0
    batch_sz = hq.shape[0]
    while cur_img < batch_sz:
        img = cur_img
        # Walk every patch of the image in row-major order, stopping at the ones inside the bounds.
        for p in range(rows * cols):
            t, l = divmod(p, cols)
            # `mode` is re-read on every step so a mode change takes effect on the next patch.
            if not valid_res[img, mode, t, l]:
                continue

            # Build a mask that shows the user the underlying construction of the image:
            # - Areas that don't fit the current labeling restrictions get a mask value of .25.
            # - Areas that do fit but are not the current patch get a mask value of .5.
            # - The current patch gets a mask value of 1.0.
            # Shape is (1, rows, cols) so it broadcasts over the image channels.
            mask = (valid_res[img, mode, :, :].clone() * .25 + .25).unsqueeze(0)
            mask[:, t, l] = 1.0
            # Upsample the mask to HQ resolution so it can be multiplied directly against the HQ image.
            mask = torch.nn.functional.interpolate(mask.unsqueeze(0), scale_factor=scale, mode="nearest").squeeze(0)
            masked = hq[img, :, :, :].clone() * mask

            # Update the image widget to show the new masked image. The attribute assignment presumably keeps a
            # Python reference alive so the PhotoImage is not garbage-collected while displayed.
            tk_picture = ImageTk.PhotoImage(to_pil(masked))
            image_widget.image = tk_picture
            image_widget.configure(image=tk_picture)

            # Show the two highest-scoring labels for this patch.
            probs, lblis = torch.topk(res[img, :, t, l], k=2)
            primary_label.config(text="%s (p=%f)" % (labeler.str_labels[lblis[0].item()], probs[0]))
            secondary_label.config(text="%s (p=%f)" % (labeler.str_labels[lblis[1].item()], probs[1]))

            # Update state variables so that the key handlers can save the current patch as needed.
            cur_top, cur_left, cur_dim = (t * scale), (l * scale), scale
            cur_path = os.path.basename(data['HQ_path'][img])
            yield True
            if cur_img != img:
                # next_image() moved on while we were suspended: restart the scan on the new image.
                break
        else:
            # Finished scanning this image without interruption.
            cur_img += 1
    cur_top, cur_left, cur_dim = None, None, None
    return False


def next_image(event):
    """Handle the "next" hotkey: skip the rest of the current image and show a patch from the next one.

    When the current batch is exhausted (or no batch has been loaded yet), a fresh batch is pulled from the
    classifier and its first qualifying patch is displayed.
    """
    global batch_gen, cur_img
    cur_img += 1
    # next_batch() notices the changed cur_img when resumed and rescans from the new image.
    if batch_gen is not None and cur_img < batch_sz and next(batch_gen, False):
        return
    batch_gen = next_batch()
    if not next(batch_gen, False):
        print("No patches in this batch fall within the confidence bounds for the current mode; "
              "press ctrl+f to load another batch.")


if __name__ == '__main__':
    classifier = PretrainedImagePatchClassifier('../options/train_imgset_structural_classifier.yml')
    gen = classifier.get_next_sample()
    labeler = VsNetImageLabeler('F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json')

    window = tk.Tk()
    window.title("Image labeler UI")
    window.geometry('512x620+100+100')

    # Photo view.
    image_widget = tk.Label(window)
    image_widget.place(x=0, y=0, width=512, height=512)

    # Labels
    primary_label = tk.Label(window, text="xxxx (p=1.0)", anchor="w")
    primary_label.place(x=20, y=510, width=400, height=20)
    secondary_label = tk.Label(window, text="yyyy (p=0.0)", anchor="w")
    secondary_label.place(x=20, y=530, width=400, height=20)
    help_label = tk.Label(window, text="Next: ctrl+f, Mode: ctrl+x, Undo: ctrl+z, Save: ctrl+s", anchor="w")
    help_label.place(x=20, y=550, width=400, height=20)
    categories_label = tk.Label(window, text=','.join(list(labeler.categories.keys())), anchor="w")
    categories_label.place(x=20, y=570, width=400, height=20)
    mode_label = tk.Label(window, text="Current mode: %s; Saved images: %i" % (labeler.str_labels[mode], 0), anchor="w")
    mode_label.place(x=20, y=590, width=400, height=20)
    widgets = (image_widget, primary_label, secondary_label, mode_label)

    # Hotkeys. These sequences must agree with the help text shown above.
    window.bind("<Control-x>", change_mode)
    window.bind("<Control-z>", undo)
    window.bind("<Control-s>", save)
    window.bind("<Control-f>", next_image)
    for kb in labeler.categories.keys():
        window.bind(str(kb), key_press)
    window.bind("<Tab>", key_press)  # Skip current patch

    # Load the first batch once the event loop is running, so the UI is usable without pressing "next" first.
    window.after_idle(next_image, None)
    window.mainloop()