Add an image patch labeling UI

James Betker 2020-12-17 10:16:21 -07:00
parent daee1b5572
commit 12cf052889
5 changed files with 303 additions and 33 deletions

View File

@@ -134,7 +134,7 @@ if __name__ == '__main__':
        'corrupt_before_downsize': True,
        'labeler': {
            'type': 'patch_labels',
            'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories.json'
            'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
        }
    }

View File: data/image_label_parser.py

@@ -1,4 +1,5 @@
import os
from collections import OrderedDict
import orjson as json
# Given a JSON file produced by the VS.net image labeler utility, produces a dict where the keys are image file names
@@ -9,39 +10,40 @@ import torch
class VsNetImageLabeler:
    def __init__(self, label_file):
        with open(label_file, "r") as read_file:
            # Format of JSON file:
            # "<nonsense>" {
            #     "label": "<label>"
            #     "keyBinding": "<nonsense>"
            #     "labeledImages": [
            #         { "path", "label", "patch_top", "patch_left", "patch_height", "patch_width" }
            #     ]
            # }
            categories = json.loads(read_file.read())

        labeled_images = {}
        available_labels = []
        for cat in categories.values():
            for lbli in cat['labeledImages']:
                pth = lbli['path']
                if pth not in labeled_images.keys():
                    labeled_images[pth] = []
                labeled_images[pth].append(lbli)
                if lbli['label'] not in available_labels:
                    available_labels.append(lbli['label'])
        if not isinstance(label_file, list):
            label_file = [label_file]
        self.labeled_images = {}
        for lfil in label_file:
            with open(lfil, "r") as read_file:
                self.label_file = label_file
                # Format of JSON file:
                # "key_binding" {
                #     "label": "<label>"
                #     "index": <num>
                #     "keyBinding": "key_binding"
                #     "labeledImages": [
                #         { "path", "label", "patch_top", "patch_left", "patch_height", "patch_width" }
                #     ]
                # }
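                # For concreteness, one category entry in that file might look like the following
                # (illustrative values only, inferred from the sketch above and from add_label() below):
                #   "s": {
                #       "label": "sky",
                #       "index": 1,
                #       "keyBinding": "s",
                #       "labeledImages": [
                #           { "path": "img_0001.jpg", "label": "sky", "patch_top": 0,
                #             "patch_left": 64, "patch_height": 64, "patch_width": 64 }
                #       ]
                #   }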
                categories = json.loads(read_file.read())

            available_labels = {}
            label_value_dict = {}
            for cat in categories.values():
                available_labels[cat['index']] = cat['label']
                label_value_dict[cat['label']] = cat['index']
                for lbli in cat['labeledImages']:
                    pth = lbli['path']
                    if pth not in self.labeled_images.keys():
                        self.labeled_images[pth] = []
                    self.labeled_images[pth].append(lbli)

        # Build the label values, from [1,inf]
        label_value_dict = {}
        for i, l in enumerate(available_labels):
            label_value_dict[l] = i

        # Insert "labelValue" for each entry.
        for v in self.labeled_images.values():
            for l in v:
                l['labelValue'] = label_value_dict[l['label']]

        # Insert "labelValue" for each entry.
        for v in labeled_images.values():
            for l in v:
                l['labelValue'] = label_value_dict[l['label']]

        self.labeled_images = labeled_images
        self.str_labels = available_labels
        self.categories = categories
        self.str_labels = available_labels
def get_labeled_paths(self, base_path):
return [os.path.join(base_path, pth) for pth in self.labeled_images]
@@ -57,4 +59,13 @@ class VsNetImageLabeler:
            val = patch_lbl['labelValue']
            labels[:,t:t+h,l:l+w] = val
            mask[:,t:t+h,l:l+w] = 1.0
        return labels, mask, self.str_labels
        return labels, mask, self.str_labels

    def add_label(self, binding, img_name, top, left, dim):
        lbl = {"path": img_name, "label": self.categories[binding]['label'], "patch_top": top, "patch_left": left,
               "patch_height": dim, "patch_width": dim}
        self.categories[binding]['labeledImages'].append(lbl)

    def save(self):
        with open(self.label_file[0], "wb") as file:
            file.write(json.dumps(self.categories))

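For reference, a minimal sketch of how this class is meant to be driven (the file name, dataset root, and key binding are hypothetical; the constructor also accepts a list of label files):

    from data.image_label_parser import VsNetImageLabeler

    labeler = VsNetImageLabeler('categories_new.json')
    # Absolute paths of every labeled image under a dataset root.
    paths = labeler.get_labeled_paths('/data/images')
    # Queue a 64x64 patch label under the category bound to key "s", then persist it.
    labeler.add_label('s', 'img_0001.jpg', 0, 64, 64)
    labeler.save()  # Rewrites the first label file via orjson.
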
View File

@@ -0,0 +1,181 @@
# Script that builds and launches a tkinter UI for directly interacting with a pytorch model that performs
# patch-level image classification.
#
# This script is a bit rough around the edges. I threw it together quickly to ascertain its usefulness. If I end up
# using it a lot, I may re-visit it.
import os
import tkinter as tk
import torch
import torchvision
from PIL import ImageTk
from data.image_label_parser import VsNetImageLabeler
from scripts.ui.image_labeler.pretrained_image_patch_classifier import PretrainedImagePatchClassifier
# Globals used to define state that event handlers might operate on.
classifier = None
gen = None
labeler = None
to_pil = torchvision.transforms.ToPILImage()
widgets = None
batch_gen = None
cur_img = 0
batch_sz = 0
mode = 0
cur_path, cur_top, cur_left, cur_dim = None, None, None, None
pending_labels = []
def update_mode_label():
    global widgets
    image_widget, primary_label, secondary_label, mode_label = widgets
    mode_label.config(text="Current mode: %s; Saved images: %i" % (labeler.str_labels[mode], len(pending_labels)))

# Handles the "change mode" hotkey. Changes the classification label being targeted.
def change_mode(event):
    global mode, pending_labels
    mode += 1
    update_mode_label()
# Handles key presses, which are interpreted as requests to categorize the currently active image patch.
def key_press(event):
    global batch_gen, labeler, pending_labels
    if event.char != '\t':
        if event.char not in labeler.categories.keys():
            print("Specified category doesn't exist.")
            return
        cat = labeler.categories[event.char]
        img_name = os.path.basename(cur_path)
        print(cat['label'], img_name, cur_top, cur_left, cur_dim)
        pending_labels.append([event.char, img_name, cur_top, cur_left, cur_dim])
    # next() returns the default (False) once the batch generator is exhausted; advance to the next image.
    if not next(batch_gen, False):
        next_image(None)
    update_mode_label()
# Pop the most recent label off of the stack.
def undo(event):
    global pending_labels
    c, nm, t, l, d = pending_labels.pop()
    print("Removed pending label", c, nm, t, l, d)
    update_mode_label()

# Save the stack of pending labels to the underlying label file.
def save(event):
    global pending_labels, labeler
    print("Saving %i labels" % len(pending_labels))
    for l in pending_labels:
        labeler.add_label(*l)
    labeler.save()
    pending_labels = []
    update_mode_label()
# This is the main controller for the state machine that this UI attaches to. At its core, it performs inference on
# a batch of images, then repeatedly yields patches of those images that fall within a confidence bound for the
# currently active label.
def next_batch():
    global gen, widgets, cur_img, batch_sz, cur_top, cur_left, cur_dim, mode, labeler, cur_path
    image_widget, primary_label, secondary_label, mode_label = widgets
    hq, res, data = next(gen)
    scale = hq.shape[-1] // res.shape[-1]
    # These are the confidence bounds. They apply to a post-softmax output. They are currently fixed.
    conf_lower = .8
    conf_upper = 1
    valid_res = ((res > conf_lower) * (res < conf_upper)) * 1.0  # 1 where values fall inside the confidence bound, 0 elsewhere.
    cur_img = 0
    batch_sz = hq.shape[0]
    while cur_img < batch_sz:  # Note: cur_img can (intentionally) be changed outside of this loop.
        # Visit every patch location in the image, looking for patches that fall within the confidence bound.
        # (A random permutation of the locations was tried at one point and is left commented out.)
        #permutation = torch.randperm(res.shape[-1] * res.shape[-2])
        for p in range(res.shape[-1]*res.shape[-2]):
            # Reconstruct a top & left coordinate.
            t = p // res.shape[-1]
            l = p % res.shape[-1]
            if not valid_res[cur_img,mode,t,l]:
                continue
            # Build a mask that shows the user the underlying construction of the image.
            # - Areas that don't fit the current labeling restrictions get a mask value of .25.
            # - Areas that do fit but are not the current patch get a mask value of .5.
            # - The current patch gets a mask value of 1.0.
            # Expected output shape is (1,h,w) so it can be multiplied into the output image.
            mask = (valid_res[cur_img,mode,:,:].clone()*.25 + .25).unsqueeze(0)
            mask[:,t,l] = 1.0
            # Interpolate the mask so that it can be directly multiplied against the HQ image.
            masked = hq[cur_img,:,:,:].clone() * torch.nn.functional.interpolate(mask.unsqueeze(0), scale_factor=scale, mode="nearest").squeeze(0)
            # Update the image widget to show the new masked image.
            tk_picture = ImageTk.PhotoImage(to_pil(masked))
            image_widget.image = tk_picture
            image_widget.configure(image=tk_picture)
            # Fill in the labels with the top-2 predicted classes for this patch.
            probs = res[cur_img, :, t, l]
            probs, lblis = torch.topk(probs, k=2)
            primary_label.config(text="%s (p=%f)" % (labeler.str_labels[lblis[0].item()], probs[0]))
            secondary_label.config(text="%s (p=%f)" % (labeler.str_labels[lblis[1].item()], probs[1]))
            # Update state variables so that the key handlers can save the current patch as needed.
            cur_top, cur_left, cur_dim = (t*scale), (l*scale), scale
            cur_path = os.path.basename(data['HQ_path'][cur_img])
            yield True
        cur_img += 1
    cur_top, cur_left, cur_dim = None, None, None
    return False
def next_image(event):
    global batch_gen, batch_sz, cur_img
    cur_img += 1
    if cur_img >= batch_sz:
        cur_img = 0
        batch_gen = next_batch()
    next(batch_gen)
if __name__ == '__main__':
    classifier = PretrainedImagePatchClassifier('../options/train_imgset_structural_classifier.yml')
    gen = classifier.get_next_sample()
    labeler = VsNetImageLabeler('F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json')

    window = tk.Tk()
    window.title("Image labeler UI")
    window.geometry('512x620+100+100')

    # Photo view.
    image_widget = tk.Label(window)
    image_widget.place(x=0, y=0, width=512, height=512)

    # Labels
    primary_label = tk.Label(window, text="xxxx (p=1.0)", anchor="w")
    primary_label.place(x=20, y=510, width=400, height=20)
    secondary_label = tk.Label(window, text="yyyy (p=0.0)", anchor="w")
    secondary_label.place(x=20, y=530, width=400, height=20)
    help = tk.Label(window, text="Next: ctrl+f, Mode: ctrl+x, Undo: ctrl+z, Save: ctrl+s", anchor="w")
    help.place(x=20, y=550, width=400, height=20)
    help2 = tk.Label(window, text=','.join(list(labeler.categories.keys())), anchor="w")
    help2.place(x=20, y=570, width=400, height=20)
    mode_label = tk.Label(window, text="Current mode: %s; Saved images: %i" % (labeler.str_labels[mode], 0), anchor="w")
    mode_label.place(x=20, y=590, width=400, height=20)
    widgets = (image_widget, primary_label, secondary_label, mode_label)

    window.bind("<Control-x>", change_mode)
    window.bind("<Control-z>", undo)
    window.bind("<Control-s>", save)
    window.bind("<Control-f>", next_image)
    for kb in labeler.categories.keys():
        window.bind("%s" % (kb,), key_press)
    window.bind("<Tab>", key_press)  # Skip current patch
    window.mainloop()

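A note on the design above: the UI is driven by a Python generator, which keeps the patch-iteration state (current image, current patch) suspended between key presses instead of tracking it in an explicit state machine. Stripped of the model and the tkinter widgets, the pattern reduces to roughly this sketch (illustrative only, not code from this commit):

    def patch_iterator(batch):
        for patch in batch:
            # Suspend here; the next key press resumes iteration via next().
            yield patch

    gen = patch_iterator(['patch_a', 'patch_b'])

    def on_key(event=None):
        patch = next(gen, None)  # None signals the batch is exhausted.
        if patch is None:
            print("Batch done; fetch a new one.")
        else:
            print("Showing", patch)
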
View File: scripts/ui/image_labeler/pretrained_image_patch_classifier.py

@@ -0,0 +1,52 @@
import logging
import os.path as osp

import utils
import utils.options as option
import utils.util as util
from data import create_dataset, create_dataloader
from models.ExtensibleTrainer import ExtensibleTrainer


class PretrainedImagePatchClassifier:
    def __init__(self, cfg):
        self.cfg = cfg

        opt = option.parse(cfg, is_train=False)
        opt = option.dict_to_nonedict(opt)
        utils.util.loaded_options = opt

        util.mkdirs(
            (path for key, path in opt['path'].items()
             if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
        util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))

        #### Create test dataset and dataloader
        dataset_opt = list(opt['datasets'].values())[0]
        # Remove labeling features from the dataset config and wrappers.
        if 'dataset' in dataset_opt.keys():
            if 'labeler' in dataset_opt['dataset'].keys():
                dataset_opt['dataset']['includes_labels'] = False
                del dataset_opt['dataset']['labeler']
            test_set = create_dataset(dataset_opt)
            if hasattr(test_set, 'wrapped_dataset'):
                test_set = test_set.wrapped_dataset
        else:
            test_set = create_dataset(dataset_opt)
        logger.info('Number of test images: {:d}'.format(len(test_set)))
        self.test_loader = create_dataloader(test_set, dataset_opt, opt)

        self.model = ExtensibleTrainer(opt)
        self.gen = self.model.netsG['generator']
        self.dataset_dir = osp.join(opt['path']['results_root'], opt['name'])
        util.mkdir(self.dataset_dir)

    def get_next_sample(self):
        for data in self.test_loader:
            hq = data['hq'].to('cuda')
            res = self.gen(hq)
            yield hq, res, data

View File

@@ -0,0 +1,26 @@
import argparse
import os

import torch
import torchvision

import utils.options as option
from scripts.ui.image_labeler.pretrained_image_patch_classifier import PretrainedImagePatchClassifier

if __name__ == "__main__":
    #### options
    torch.backends.cudnn.benchmark = True
    want_metrics = False
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_imgset_structural_classifier.yml')
    classifier = PretrainedImagePatchClassifier(parser.parse_args().opt)

    label_to_search_for = 4
    step = 1
    # get_next_sample() yields (hq, res, data) triples; only hq and res are used here.
    for hq, res, data in classifier.get_next_sample():
        res = torch.nn.functional.interpolate(res, size=hq.shape[2:], mode="nearest")
        res_lbl = res[:, label_to_search_for, :, :].unsqueeze(1)
        # Keep regions confidently classified as the target label at full brightness (mask=1); dim the rest by half (mask=.5).
        res_lbl_mask = (1.0 * (res_lbl > .5))*.5 + .5
        hq = hq * res_lbl_mask
        torchvision.utils.save_image(hq, os.path.join(classifier.dataset_dir, "%i.png" % (step,)))
        step += 1