forked from mrq/DL-Art-School
Add an image patch labeling UI
This commit is contained in:
parent
daee1b5572
commit
12cf052889
|
@ -134,7 +134,7 @@ if __name__ == '__main__':
|
|||
'corrupt_before_downsize': True,
|
||||
'labeler': {
|
||||
'type': 'patch_labels',
|
||||
'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories.json'
|
||||
'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import orjson as json
|
||||
# Given a JSON file produced by the VS.net image labeler utility, produces a dict where the keys are image file names
|
||||
|
@ -9,39 +10,40 @@ import torch
|
|||
|
||||
class VsNetImageLabeler:
|
||||
def __init__(self, label_file):
|
||||
with open(label_file, "r") as read_file:
|
||||
# Format of JSON file:
|
||||
# "<nonsense>" {
|
||||
# "label": "<label>"
|
||||
# "keyBinding": "<nonsense>"
|
||||
# "labeledImages": [
|
||||
# { "path", "label", "patch_top", "patch_left", "patch_height", "patch_width" }
|
||||
# ]
|
||||
# }
|
||||
categories = json.loads(read_file.read())
|
||||
labeled_images = {}
|
||||
available_labels = []
|
||||
for cat in categories.values():
|
||||
for lbli in cat['labeledImages']:
|
||||
pth = lbli['path']
|
||||
if pth not in labeled_images.keys():
|
||||
labeled_images[pth] = []
|
||||
labeled_images[pth].append(lbli)
|
||||
if lbli['label'] not in available_labels:
|
||||
available_labels.append(lbli['label'])
|
||||
if not isinstance(label_file, list):
|
||||
label_file = [label_file]
|
||||
self.labeled_images = {}
|
||||
for lfil in label_file:
|
||||
with open(lfil, "r") as read_file:
|
||||
self.label_file = label_file
|
||||
# Format of JSON file:
|
||||
# "key_binding" {
|
||||
# "label": "<label>"
|
||||
# "index": <num>
|
||||
# "keyBinding": "key_binding"
|
||||
# "labeledImages": [
|
||||
# { "path", "label", "patch_top", "patch_left", "patch_height", "patch_width" }
|
||||
# ]
|
||||
# }
|
||||
categories = json.loads(read_file.read())
|
||||
available_labels = {}
|
||||
label_value_dict = {}
|
||||
for cat in categories.values():
|
||||
available_labels[cat['index']] = cat['label']
|
||||
label_value_dict[cat['label']] = cat['index']
|
||||
for lbli in cat['labeledImages']:
|
||||
pth = lbli['path']
|
||||
if pth not in self.labeled_images.keys():
|
||||
self.labeled_images[pth] = []
|
||||
self.labeled_images[pth].append(lbli)
|
||||
|
||||
# Build the label values, from [1,inf]
|
||||
label_value_dict = {}
|
||||
for i, l in enumerate(available_labels):
|
||||
label_value_dict[l] = i
|
||||
# Insert "labelValue" for each entry.
|
||||
for v in self.labeled_images.values():
|
||||
for l in v:
|
||||
l['labelValue'] = label_value_dict[l['label']]
|
||||
|
||||
# Insert "labelValue" for each entry.
|
||||
for v in labeled_images.values():
|
||||
for l in v:
|
||||
l['labelValue'] = label_value_dict[l['label']]
|
||||
|
||||
self.labeled_images = labeled_images
|
||||
self.str_labels = available_labels
|
||||
self.categories = categories
|
||||
self.str_labels = available_labels
|
||||
|
||||
def get_labeled_paths(self, base_path):
|
||||
return [os.path.join(base_path, pth) for pth in self.labeled_images]
|
||||
|
@ -57,4 +59,13 @@ class VsNetImageLabeler:
|
|||
val = patch_lbl['labelValue']
|
||||
labels[:,t:t+h,l:l+w] = val
|
||||
mask[:,t:t+h,l:l+w] = 1.0
|
||||
return labels, mask, self.str_labels
|
||||
return labels, mask, self.str_labels
|
||||
|
||||
def add_label(self, binding, img_name, top, left, dim):
|
||||
lbl = {"path": img_name, "label": self.categories[binding]['label'], "patch_top": top, "patch_left": left,
|
||||
"patch_height": dim, "patch_width": dim}
|
||||
self.categories[binding]['labeledImages'].append(lbl)
|
||||
|
||||
def save(self):
|
||||
with open(self.label_file[0], "wb") as file:
|
||||
file.write(json.dumps(self.categories))
|
||||
|
|
181
codes/scripts/ui/image_labeler/image_labeler_ui.py
Normal file
181
codes/scripts/ui/image_labeler/image_labeler_ui.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
# Script that builds and launches a tkinter UI for directly interacting with a pytorch model that performs
|
||||
# patch-level image classification.
|
||||
#
|
||||
# This script is a bit rough around the edges. I threw it together quickly to ascertain its usefulness. If I end up
|
||||
# using it a lot, I may re-visit it.
|
||||
|
||||
import os
|
||||
import tkinter as tk
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import ImageTk
|
||||
|
||||
from data.image_label_parser import VsNetImageLabeler
|
||||
from scripts.ui.image_labeler.pretrained_image_patch_classifier import PretrainedImagePatchClassifier
|
||||
|
||||
# Globals used to define state that event handlers might operate on.
|
||||
classifier = None
|
||||
gen = None
|
||||
labeler = None
|
||||
to_pil = torchvision.transforms.ToPILImage()
|
||||
widgets = None
|
||||
batch_gen = None
|
||||
cur_img = 0
|
||||
batch_sz = 0
|
||||
mode = 0
|
||||
cur_path, cur_top, cur_left, cur_dim = None, None, None, None
|
||||
pending_labels = []
|
||||
|
||||
|
||||
def update_mode_label():
|
||||
global widgets
|
||||
image_widget, primary_label, secondary_label, mode_label = widgets
|
||||
mode_label.config(text="Current mode: %s; Saved images: %i" % (labeler.str_labels[mode], len(pending_labels)))
|
||||
|
||||
|
||||
# Handles the "change mode" hotkey. Changes the classification label being targeted.
|
||||
def change_mode(event):
|
||||
global mode, pending_labels
|
||||
mode += 1
|
||||
update_mode_label()
|
||||
|
||||
|
||||
# Handles key presses, which are interpreted as requests to categorize a currently active image patch.
|
||||
def key_press(event):
|
||||
global batch_gen, labeler, pending_labels
|
||||
|
||||
if event.char != '\t':
|
||||
if event.char not in labeler.categories.keys():
|
||||
print("Specified category doesn't exist.")
|
||||
return
|
||||
cat = labeler.categories[event.char]
|
||||
img_name = os.path.basename(cur_path)
|
||||
print(cat['label'], img_name, cur_top, cur_left, cur_dim)
|
||||
pending_labels.append([event.char, img_name, cur_top, cur_left, cur_dim])
|
||||
|
||||
if not next(batch_gen):
|
||||
next_image(None)
|
||||
update_mode_label()
|
||||
|
||||
|
||||
# Pop the most recent label off of the stack.
|
||||
def undo(event):
|
||||
global pending_labels
|
||||
c, nm, t, l, d = pending_labels.pop()
|
||||
print("Removed pending label", c, nm, t, l, d)
|
||||
update_mode_label()
|
||||
|
||||
|
||||
# Save the stack of pending labels to the underlying label file.
|
||||
def save(event):
|
||||
global pending_labels, labeler
|
||||
print("Saving %i labels", len(pending_labels,))
|
||||
for l in pending_labels:
|
||||
labeler.add_label(*l)
|
||||
labeler.save()
|
||||
pending_labels = []
|
||||
update_mode_label()
|
||||
|
||||
|
||||
# This is the main controller for the state machine that this UI attaches to. At its core, it performs inference on
|
||||
# a batch of images, then repetitively yields patches of images that fit within a confidence bound for the currently
|
||||
# active label.
|
||||
def next_batch():
|
||||
global gen, widgets, cur_img, batch_sz, cur_top, cur_left, cur_dim, mode, labeler, cur_path
|
||||
image_widget, primary_label, secondary_label, mode_label = widgets
|
||||
hq, res, data = next(gen)
|
||||
scale = hq.shape[-1] // res.shape[-1]
|
||||
|
||||
# These are the confidence bounds. They apply to a post-softmax output. They are currently fixed.
|
||||
conf_lower = .8
|
||||
conf_upper = 1
|
||||
valid_res = ((res > conf_lower) * (res < conf_upper)) * 1.0 # This results in a tensor of 0's where tensors are outside of the confidence bound, 1's where inside.
|
||||
|
||||
cur_img = 0
|
||||
batch_sz = hq.shape[0]
|
||||
while cur_img < batch_sz: # Note: cur_img can (intentionally) be changed outside of this loop.
|
||||
# Build a random permutation for every image patch in the image. We will search for patches that fall within the confidence bound and yield them.
|
||||
#permutation = torch.randperm(res.shape[-1] * res.shape[-2])
|
||||
for p in range(res.shape[-1]*res.shape[-2]):
|
||||
# Reconstruct a top & left coordinate.
|
||||
t = p // res.shape[-1]
|
||||
l = p % res.shape[-1]
|
||||
if not valid_res[cur_img,mode,t,l]:
|
||||
continue
|
||||
|
||||
# Build a mask that shows the user the underlying construction of the image.
|
||||
# - Areas that don't fit the current labeling restrictions get a mask value of .25.
|
||||
# - Areas that do fit but are not the current patch, get a mask value of .5
|
||||
# - The current patch gets a mask value of 1.0
|
||||
# Expected output shape is (1,h,w) so it can be multiplied into the output image.
|
||||
mask = (valid_res[cur_img,mode,:,:].clone()*.25 + .25).unsqueeze(0)
|
||||
mask[:,t,l] = 1.0
|
||||
|
||||
# Interpolate the mask so that it can be directly multiplied against the HQ image.
|
||||
masked = hq[cur_img,:,:,:].clone() * torch.nn.functional.interpolate(mask.unsqueeze(0), scale_factor=scale, mode="nearest").squeeze(0)
|
||||
|
||||
# Update the image widget to show the new masked image.
|
||||
tk_picture = ImageTk.PhotoImage(to_pil(masked))
|
||||
image_widget.image = tk_picture
|
||||
image_widget.configure(image=tk_picture)
|
||||
|
||||
# Fill in the labels
|
||||
probs = res[cur_img, :, t, l]
|
||||
probs, lblis = torch.topk(probs, k=2)
|
||||
primary_label.config(text="%s (p=%f)" % (labeler.str_labels[lblis[0].item()], probs[0]))
|
||||
secondary_label.config(text="%s (p=%f)" % (labeler.str_labels[lblis[1].item()], probs[1]))
|
||||
|
||||
# Update state variables so that the key handlers can save the current patch as needed.
|
||||
cur_top, cur_left, cur_dim = (t*scale), (l*scale), scale
|
||||
cur_path = os.path.basename(data['HQ_path'][cur_img])
|
||||
yield True
|
||||
cur_img += 1
|
||||
cur_top, cur_left, cur_dim = None, None, None
|
||||
return False
|
||||
|
||||
|
||||
def next_image(event):
|
||||
global batch_gen, batch_sz, cur_img
|
||||
cur_img += 1
|
||||
if cur_img >= batch_sz:
|
||||
cur_img = 0
|
||||
batch_gen = next_batch()
|
||||
next(batch_gen)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
classifier = PretrainedImagePatchClassifier('../options/train_imgset_structural_classifier.yml')
|
||||
gen = classifier.get_next_sample()
|
||||
labeler = VsNetImageLabeler('F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json')
|
||||
|
||||
window = tk.Tk()
|
||||
window.title("Image labeler UI")
|
||||
window.geometry('512x620+100+100')
|
||||
|
||||
# Photo view.
|
||||
image_widget = tk.Label(window)
|
||||
image_widget.place(x=0, y=0, width=512, height=512)
|
||||
|
||||
# Labels
|
||||
primary_label = tk.Label(window, text="xxxx (p=1.0)", anchor="w")
|
||||
primary_label.place(x=20, y=510, width=400, height=20)
|
||||
secondary_label = tk.Label(window, text="yyyy (p=0.0)", anchor="w")
|
||||
secondary_label.place(x=20, y=530, width=400, height=20)
|
||||
help = tk.Label(window, text="Next: ctrl+f, Mode: ctrl+x, Undo: ctrl+z, Save: ctrl+s", anchor="w")
|
||||
help.place(x=20, y=550, width=400, height=20)
|
||||
help2 = tk.Label(window, text=','.join(list(labeler.categories.keys())), anchor="w")
|
||||
help2.place(x=20, y=570, width=400, height=20)
|
||||
mode_label = tk.Label(window, text="Current mode: %s; Saved images: %i" % (labeler.str_labels[mode], 0), anchor="w")
|
||||
mode_label.place(x=20, y=590, width=400, height=20)
|
||||
|
||||
widgets = (image_widget, primary_label, secondary_label, mode_label)
|
||||
|
||||
window.bind("<Control-x>", change_mode)
|
||||
window.bind("<Control-z>", undo)
|
||||
window.bind("<Control-s>", save)
|
||||
window.bind("<Control-f>", next_image)
|
||||
for kb in labeler.categories.keys():
|
||||
window.bind("%s" % (kb,), key_press)
|
||||
window.bind("<Tab>", key_press) # Skip current patch
|
||||
window.mainloop()
|
|
@ -0,0 +1,52 @@
|
|||
import logging
|
||||
import os.path as osp
|
||||
|
||||
import utils
|
||||
import utils.options as option
|
||||
import utils.util as util
|
||||
from data import create_dataset, create_dataloader
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
|
||||
|
||||
class PretrainedImagePatchClassifier:
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
|
||||
opt = option.parse(cfg, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
utils.util.loaded_options = opt
|
||||
|
||||
util.mkdirs(
|
||||
(path for key, path in opt['path'].items()
|
||||
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
|
||||
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
|
||||
screen=True, tofile=True)
|
||||
logger = logging.getLogger('base')
|
||||
logger.info(option.dict2str(opt))
|
||||
|
||||
#### Create test dataset and dataloader
|
||||
dataset_opt = list(opt['datasets'].values())[0]
|
||||
# Remove labeling features from the dataset config and wrappers.
|
||||
if 'dataset' in dataset_opt.keys():
|
||||
if 'labeler' in dataset_opt['dataset'].keys():
|
||||
dataset_opt['dataset']['includes_labels'] = False
|
||||
del dataset_opt['dataset']['labeler']
|
||||
test_set = create_dataset(dataset_opt)
|
||||
if hasattr(test_set, 'wrapped_dataset'):
|
||||
test_set = test_set.wrapped_dataset
|
||||
else:
|
||||
test_set = create_dataset(dataset_opt)
|
||||
logger.info('Number of test images: {:d}'.format(len(test_set)))
|
||||
self.test_loader = create_dataloader(test_set, dataset_opt, opt)
|
||||
self.model = ExtensibleTrainer(opt)
|
||||
self.gen = self.model.netsG['generator']
|
||||
self.dataset_dir = osp.join(opt['path']['results_root'], opt['name'])
|
||||
util.mkdir(self.dataset_dir)
|
||||
|
||||
def get_next_sample(self):
|
||||
|
||||
for data in self.test_loader:
|
||||
hq = data['hq'].to('cuda')
|
||||
res = self.gen(hq)
|
||||
yield hq, res, data
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
import utils.options as option
|
||||
from scripts.ui.image_labeler.pretrained_image_patch_classifier import PretrainedImagePatchClassifier
|
||||
|
||||
if __name__ == "__main__":
|
||||
#### options
|
||||
torch.backends.cudnn.benchmark = True
|
||||
want_metrics = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_imgset_structural_classifier.yml')
|
||||
|
||||
classifier = PretrainedImagePatchClassifier(parser.parse_args().opt)
|
||||
label_to_search_for = 4
|
||||
step = 1
|
||||
for hq, res in classifier.get_next_sample():
|
||||
res = torch.nn.functional.interpolate(res, size=hq.shape[2:], mode="nearest")
|
||||
res_lbl = res[:, label_to_search_for, :, :].unsqueeze(1)
|
||||
res_lbl_mask = (1.0 * (res_lbl > .5))*.5 + .5
|
||||
hq = hq * res_lbl_mask
|
||||
torchvision.utils.save_image(hq, os.path.join(classifier.dataset_dir, "%i.png" % (step,)))
|
||||
step += 1
|
Loading…
Reference in New Issue
Block a user