2020-12-17 17:16:21 +00:00
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torchvision
|
2023-03-21 15:39:28 +00:00
|
|
|
from scripts.ui.image_labeler.pretrained_image_patch_classifier import \
|
|
|
|
PretrainedImagePatchClassifier
|
2020-12-17 17:16:21 +00:00
|
|
|
|
2023-03-21 15:39:28 +00:00
|
|
|
import dlas.utils.options as option
|
2020-12-17 17:16:21 +00:00
|
|
|
|
|
|
|
if __name__ == "__main__":
|
2023-03-21 15:39:28 +00:00
|
|
|
# options
|
2020-12-17 17:16:21 +00:00
|
|
|
torch.backends.cudnn.benchmark = True
|
|
|
|
want_metrics = False
|
|
|
|
parser = argparse.ArgumentParser()
|
2023-03-21 15:39:28 +00:00
|
|
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.',
|
|
|
|
default='../options/train_imgset_structural_classifier.yml')
|
2020-12-17 17:16:21 +00:00
|
|
|
|
|
|
|
classifier = PretrainedImagePatchClassifier(parser.parse_args().opt)
|
|
|
|
label_to_search_for = 4
|
|
|
|
step = 1
|
|
|
|
for hq, res in classifier.get_next_sample():
|
2023-03-21 15:39:28 +00:00
|
|
|
res = torch.nn.functional.interpolate(
|
|
|
|
res, size=hq.shape[2:], mode="nearest")
|
2020-12-17 17:16:21 +00:00
|
|
|
res_lbl = res[:, label_to_search_for, :, :].unsqueeze(1)
|
|
|
|
res_lbl_mask = (1.0 * (res_lbl > .5))*.5 + .5
|
|
|
|
hq = hq * res_lbl_mask
|
2023-03-21 15:39:28 +00:00
|
|
|
torchvision.utils.save_image(hq, os.path.join(
|
|
|
|
classifier.dataset_dir, "%i.png" % (step,)))
|
2020-12-17 17:16:21 +00:00
|
|
|
step += 1
|