Add an image patch labeling UI
This commit is contained in:
parent
daee1b5572
commit
12cf052889
|
@ -134,7 +134,7 @@ if __name__ == '__main__':
|
||||||
'corrupt_before_downsize': True,
|
'corrupt_before_downsize': True,
|
||||||
'labeler': {
|
'labeler': {
|
||||||
'type': 'patch_labels',
|
'type': 'patch_labels',
|
||||||
'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories.json'
|
'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
import orjson as json
|
import orjson as json
|
||||||
# Given a JSON file produced by the VS.net image labeler utility, produces a dict where the keys are image file names
|
# Given a JSON file produced by the VS.net image labeler utility, produces a dict where the keys are image file names
|
||||||
|
@ -9,39 +10,40 @@ import torch
|
||||||
|
|
||||||
class VsNetImageLabeler:
|
class VsNetImageLabeler:
|
||||||
def __init__(self, label_file):
|
def __init__(self, label_file):
|
||||||
with open(label_file, "r") as read_file:
|
if not isinstance(label_file, list):
|
||||||
# Format of JSON file:
|
label_file = [label_file]
|
||||||
# "<nonsense>" {
|
self.labeled_images = {}
|
||||||
# "label": "<label>"
|
for lfil in label_file:
|
||||||
# "keyBinding": "<nonsense>"
|
with open(lfil, "r") as read_file:
|
||||||
# "labeledImages": [
|
self.label_file = label_file
|
||||||
# { "path", "label", "patch_top", "patch_left", "patch_height", "patch_width" }
|
# Format of JSON file:
|
||||||
# ]
|
# "key_binding" {
|
||||||
# }
|
# "label": "<label>"
|
||||||
categories = json.loads(read_file.read())
|
# "index": <num>
|
||||||
labeled_images = {}
|
# "keyBinding": "key_binding"
|
||||||
available_labels = []
|
# "labeledImages": [
|
||||||
for cat in categories.values():
|
# { "path", "label", "patch_top", "patch_left", "patch_height", "patch_width" }
|
||||||
for lbli in cat['labeledImages']:
|
# ]
|
||||||
pth = lbli['path']
|
# }
|
||||||
if pth not in labeled_images.keys():
|
categories = json.loads(read_file.read())
|
||||||
labeled_images[pth] = []
|
available_labels = {}
|
||||||
labeled_images[pth].append(lbli)
|
label_value_dict = {}
|
||||||
if lbli['label'] not in available_labels:
|
for cat in categories.values():
|
||||||
available_labels.append(lbli['label'])
|
available_labels[cat['index']] = cat['label']
|
||||||
|
label_value_dict[cat['label']] = cat['index']
|
||||||
|
for lbli in cat['labeledImages']:
|
||||||
|
pth = lbli['path']
|
||||||
|
if pth not in self.labeled_images.keys():
|
||||||
|
self.labeled_images[pth] = []
|
||||||
|
self.labeled_images[pth].append(lbli)
|
||||||
|
|
||||||
# Build the label values, from [1,inf]
|
# Insert "labelValue" for each entry.
|
||||||
label_value_dict = {}
|
for v in self.labeled_images.values():
|
||||||
for i, l in enumerate(available_labels):
|
for l in v:
|
||||||
label_value_dict[l] = i
|
l['labelValue'] = label_value_dict[l['label']]
|
||||||
|
|
||||||
# Insert "labelValue" for each entry.
|
self.categories = categories
|
||||||
for v in labeled_images.values():
|
self.str_labels = available_labels
|
||||||
for l in v:
|
|
||||||
l['labelValue'] = label_value_dict[l['label']]
|
|
||||||
|
|
||||||
self.labeled_images = labeled_images
|
|
||||||
self.str_labels = available_labels
|
|
||||||
|
|
||||||
def get_labeled_paths(self, base_path):
|
def get_labeled_paths(self, base_path):
|
||||||
return [os.path.join(base_path, pth) for pth in self.labeled_images]
|
return [os.path.join(base_path, pth) for pth in self.labeled_images]
|
||||||
|
@ -57,4 +59,13 @@ class VsNetImageLabeler:
|
||||||
val = patch_lbl['labelValue']
|
val = patch_lbl['labelValue']
|
||||||
labels[:,t:t+h,l:l+w] = val
|
labels[:,t:t+h,l:l+w] = val
|
||||||
mask[:,t:t+h,l:l+w] = 1.0
|
mask[:,t:t+h,l:l+w] = 1.0
|
||||||
return labels, mask, self.str_labels
|
return labels, mask, self.str_labels
|
||||||
|
|
||||||
|
def add_label(self, binding, img_name, top, left, dim):
|
||||||
|
lbl = {"path": img_name, "label": self.categories[binding]['label'], "patch_top": top, "patch_left": left,
|
||||||
|
"patch_height": dim, "patch_width": dim}
|
||||||
|
self.categories[binding]['labeledImages'].append(lbl)
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
with open(self.label_file[0], "wb") as file:
|
||||||
|
file.write(json.dumps(self.categories))
|
||||||
|
|
181
codes/scripts/ui/image_labeler/image_labeler_ui.py
Normal file
181
codes/scripts/ui/image_labeler/image_labeler_ui.py
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
# Script that builds and launches a tkinter UI for directly interacting with a pytorch model that performs
|
||||||
|
# patch-level image classification.
|
||||||
|
#
|
||||||
|
# This script is a bit rough around the edges. I threw it together quickly to ascertain its usefulness. If I end up
|
||||||
|
# using it a lot, I may re-visit it.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tkinter as tk
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from PIL import ImageTk
|
||||||
|
|
||||||
|
from data.image_label_parser import VsNetImageLabeler
|
||||||
|
from scripts.ui.image_labeler.pretrained_image_patch_classifier import PretrainedImagePatchClassifier
|
||||||
|
|
||||||
|
# Globals used to define state that event handlers might operate on.
|
||||||
|
classifier = None
|
||||||
|
gen = None
|
||||||
|
labeler = None
|
||||||
|
to_pil = torchvision.transforms.ToPILImage()
|
||||||
|
widgets = None
|
||||||
|
batch_gen = None
|
||||||
|
cur_img = 0
|
||||||
|
batch_sz = 0
|
||||||
|
mode = 0
|
||||||
|
cur_path, cur_top, cur_left, cur_dim = None, None, None, None
|
||||||
|
pending_labels = []
|
||||||
|
|
||||||
|
|
||||||
|
def update_mode_label():
|
||||||
|
global widgets
|
||||||
|
image_widget, primary_label, secondary_label, mode_label = widgets
|
||||||
|
mode_label.config(text="Current mode: %s; Saved images: %i" % (labeler.str_labels[mode], len(pending_labels)))
|
||||||
|
|
||||||
|
|
||||||
|
# Handles the "change mode" hotkey. Changes the classification label being targeted.
|
||||||
|
def change_mode(event):
|
||||||
|
global mode, pending_labels
|
||||||
|
mode += 1
|
||||||
|
update_mode_label()
|
||||||
|
|
||||||
|
|
||||||
|
# Handles key presses, which are interpreted as requests to categorize a currently active image patch.
|
||||||
|
def key_press(event):
|
||||||
|
global batch_gen, labeler, pending_labels
|
||||||
|
|
||||||
|
if event.char != '\t':
|
||||||
|
if event.char not in labeler.categories.keys():
|
||||||
|
print("Specified category doesn't exist.")
|
||||||
|
return
|
||||||
|
cat = labeler.categories[event.char]
|
||||||
|
img_name = os.path.basename(cur_path)
|
||||||
|
print(cat['label'], img_name, cur_top, cur_left, cur_dim)
|
||||||
|
pending_labels.append([event.char, img_name, cur_top, cur_left, cur_dim])
|
||||||
|
|
||||||
|
if not next(batch_gen):
|
||||||
|
next_image(None)
|
||||||
|
update_mode_label()
|
||||||
|
|
||||||
|
|
||||||
|
# Pop the most recent label off of the stack.
|
||||||
|
def undo(event):
|
||||||
|
global pending_labels
|
||||||
|
c, nm, t, l, d = pending_labels.pop()
|
||||||
|
print("Removed pending label", c, nm, t, l, d)
|
||||||
|
update_mode_label()
|
||||||
|
|
||||||
|
|
||||||
|
# Save the stack of pending labels to the underlying label file.
|
||||||
|
def save(event):
|
||||||
|
global pending_labels, labeler
|
||||||
|
print("Saving %i labels", len(pending_labels,))
|
||||||
|
for l in pending_labels:
|
||||||
|
labeler.add_label(*l)
|
||||||
|
labeler.save()
|
||||||
|
pending_labels = []
|
||||||
|
update_mode_label()
|
||||||
|
|
||||||
|
|
||||||
|
# This is the main controller for the state machine that this UI attaches to. At its core, it performs inference on
|
||||||
|
# a batch of images, then repetitively yields patches of images that fit within a confidence bound for the currently
|
||||||
|
# active label.
|
||||||
|
def next_batch():
|
||||||
|
global gen, widgets, cur_img, batch_sz, cur_top, cur_left, cur_dim, mode, labeler, cur_path
|
||||||
|
image_widget, primary_label, secondary_label, mode_label = widgets
|
||||||
|
hq, res, data = next(gen)
|
||||||
|
scale = hq.shape[-1] // res.shape[-1]
|
||||||
|
|
||||||
|
# These are the confidence bounds. They apply to a post-softmax output. They are currently fixed.
|
||||||
|
conf_lower = .8
|
||||||
|
conf_upper = 1
|
||||||
|
valid_res = ((res > conf_lower) * (res < conf_upper)) * 1.0 # This results in a tensor of 0's where tensors are outside of the confidence bound, 1's where inside.
|
||||||
|
|
||||||
|
cur_img = 0
|
||||||
|
batch_sz = hq.shape[0]
|
||||||
|
while cur_img < batch_sz: # Note: cur_img can (intentionally) be changed outside of this loop.
|
||||||
|
# Build a random permutation for every image patch in the image. We will search for patches that fall within the confidence bound and yield them.
|
||||||
|
#permutation = torch.randperm(res.shape[-1] * res.shape[-2])
|
||||||
|
for p in range(res.shape[-1]*res.shape[-2]):
|
||||||
|
# Reconstruct a top & left coordinate.
|
||||||
|
t = p // res.shape[-1]
|
||||||
|
l = p % res.shape[-1]
|
||||||
|
if not valid_res[cur_img,mode,t,l]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build a mask that shows the user the underlying construction of the image.
|
||||||
|
# - Areas that don't fit the current labeling restrictions get a mask value of .25.
|
||||||
|
# - Areas that do fit but are not the current patch, get a mask value of .5
|
||||||
|
# - The current patch gets a mask value of 1.0
|
||||||
|
# Expected output shape is (1,h,w) so it can be multiplied into the output image.
|
||||||
|
mask = (valid_res[cur_img,mode,:,:].clone()*.25 + .25).unsqueeze(0)
|
||||||
|
mask[:,t,l] = 1.0
|
||||||
|
|
||||||
|
# Interpolate the mask so that it can be directly multiplied against the HQ image.
|
||||||
|
masked = hq[cur_img,:,:,:].clone() * torch.nn.functional.interpolate(mask.unsqueeze(0), scale_factor=scale, mode="nearest").squeeze(0)
|
||||||
|
|
||||||
|
# Update the image widget to show the new masked image.
|
||||||
|
tk_picture = ImageTk.PhotoImage(to_pil(masked))
|
||||||
|
image_widget.image = tk_picture
|
||||||
|
image_widget.configure(image=tk_picture)
|
||||||
|
|
||||||
|
# Fill in the labels
|
||||||
|
probs = res[cur_img, :, t, l]
|
||||||
|
probs, lblis = torch.topk(probs, k=2)
|
||||||
|
primary_label.config(text="%s (p=%f)" % (labeler.str_labels[lblis[0].item()], probs[0]))
|
||||||
|
secondary_label.config(text="%s (p=%f)" % (labeler.str_labels[lblis[1].item()], probs[1]))
|
||||||
|
|
||||||
|
# Update state variables so that the key handlers can save the current patch as needed.
|
||||||
|
cur_top, cur_left, cur_dim = (t*scale), (l*scale), scale
|
||||||
|
cur_path = os.path.basename(data['HQ_path'][cur_img])
|
||||||
|
yield True
|
||||||
|
cur_img += 1
|
||||||
|
cur_top, cur_left, cur_dim = None, None, None
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def next_image(event):
|
||||||
|
global batch_gen, batch_sz, cur_img
|
||||||
|
cur_img += 1
|
||||||
|
if cur_img >= batch_sz:
|
||||||
|
cur_img = 0
|
||||||
|
batch_gen = next_batch()
|
||||||
|
next(batch_gen)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
classifier = PretrainedImagePatchClassifier('../options/train_imgset_structural_classifier.yml')
|
||||||
|
gen = classifier.get_next_sample()
|
||||||
|
labeler = VsNetImageLabeler('F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json')
|
||||||
|
|
||||||
|
window = tk.Tk()
|
||||||
|
window.title("Image labeler UI")
|
||||||
|
window.geometry('512x620+100+100')
|
||||||
|
|
||||||
|
# Photo view.
|
||||||
|
image_widget = tk.Label(window)
|
||||||
|
image_widget.place(x=0, y=0, width=512, height=512)
|
||||||
|
|
||||||
|
# Labels
|
||||||
|
primary_label = tk.Label(window, text="xxxx (p=1.0)", anchor="w")
|
||||||
|
primary_label.place(x=20, y=510, width=400, height=20)
|
||||||
|
secondary_label = tk.Label(window, text="yyyy (p=0.0)", anchor="w")
|
||||||
|
secondary_label.place(x=20, y=530, width=400, height=20)
|
||||||
|
help = tk.Label(window, text="Next: ctrl+f, Mode: ctrl+x, Undo: ctrl+z, Save: ctrl+s", anchor="w")
|
||||||
|
help.place(x=20, y=550, width=400, height=20)
|
||||||
|
help2 = tk.Label(window, text=','.join(list(labeler.categories.keys())), anchor="w")
|
||||||
|
help2.place(x=20, y=570, width=400, height=20)
|
||||||
|
mode_label = tk.Label(window, text="Current mode: %s; Saved images: %i" % (labeler.str_labels[mode], 0), anchor="w")
|
||||||
|
mode_label.place(x=20, y=590, width=400, height=20)
|
||||||
|
|
||||||
|
widgets = (image_widget, primary_label, secondary_label, mode_label)
|
||||||
|
|
||||||
|
window.bind("<Control-x>", change_mode)
|
||||||
|
window.bind("<Control-z>", undo)
|
||||||
|
window.bind("<Control-s>", save)
|
||||||
|
window.bind("<Control-f>", next_image)
|
||||||
|
for kb in labeler.categories.keys():
|
||||||
|
window.bind("%s" % (kb,), key_press)
|
||||||
|
window.bind("<Tab>", key_press) # Skip current patch
|
||||||
|
window.mainloop()
|
|
@ -0,0 +1,52 @@
|
||||||
|
import logging
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
import utils
|
||||||
|
import utils.options as option
|
||||||
|
import utils.util as util
|
||||||
|
from data import create_dataset, create_dataloader
|
||||||
|
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class PretrainedImagePatchClassifier:
|
||||||
|
def __init__(self, cfg):
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
opt = option.parse(cfg, is_train=False)
|
||||||
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
utils.util.loaded_options = opt
|
||||||
|
|
||||||
|
util.mkdirs(
|
||||||
|
(path for key, path in opt['path'].items()
|
||||||
|
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
|
||||||
|
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
|
||||||
|
screen=True, tofile=True)
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
logger.info(option.dict2str(opt))
|
||||||
|
|
||||||
|
#### Create test dataset and dataloader
|
||||||
|
dataset_opt = list(opt['datasets'].values())[0]
|
||||||
|
# Remove labeling features from the dataset config and wrappers.
|
||||||
|
if 'dataset' in dataset_opt.keys():
|
||||||
|
if 'labeler' in dataset_opt['dataset'].keys():
|
||||||
|
dataset_opt['dataset']['includes_labels'] = False
|
||||||
|
del dataset_opt['dataset']['labeler']
|
||||||
|
test_set = create_dataset(dataset_opt)
|
||||||
|
if hasattr(test_set, 'wrapped_dataset'):
|
||||||
|
test_set = test_set.wrapped_dataset
|
||||||
|
else:
|
||||||
|
test_set = create_dataset(dataset_opt)
|
||||||
|
logger.info('Number of test images: {:d}'.format(len(test_set)))
|
||||||
|
self.test_loader = create_dataloader(test_set, dataset_opt, opt)
|
||||||
|
self.model = ExtensibleTrainer(opt)
|
||||||
|
self.gen = self.model.netsG['generator']
|
||||||
|
self.dataset_dir = osp.join(opt['path']['results_root'], opt['name'])
|
||||||
|
util.mkdir(self.dataset_dir)
|
||||||
|
|
||||||
|
def get_next_sample(self):
|
||||||
|
|
||||||
|
for data in self.test_loader:
|
||||||
|
hq = data['hq'].to('cuda')
|
||||||
|
res = self.gen(hq)
|
||||||
|
yield hq, res, data
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
import utils.options as option
|
||||||
|
from scripts.ui.image_labeler.pretrained_image_patch_classifier import PretrainedImagePatchClassifier
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
#### options
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
want_metrics = False
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_imgset_structural_classifier.yml')
|
||||||
|
|
||||||
|
classifier = PretrainedImagePatchClassifier(parser.parse_args().opt)
|
||||||
|
label_to_search_for = 4
|
||||||
|
step = 1
|
||||||
|
for hq, res in classifier.get_next_sample():
|
||||||
|
res = torch.nn.functional.interpolate(res, size=hq.shape[2:], mode="nearest")
|
||||||
|
res_lbl = res[:, label_to_search_for, :, :].unsqueeze(1)
|
||||||
|
res_lbl_mask = (1.0 * (res_lbl > .5))*.5 + .5
|
||||||
|
hq = hq * res_lbl_mask
|
||||||
|
torchvision.utils.save_image(hq, os.path.join(classifier.dataset_dir, "%i.png" % (step,)))
|
||||||
|
step += 1
|
Loading…
Reference in New Issue
Block a user