DL-Art-School/codes/scripts/ui/image_labeler/test_image_patch_classifier.py

27 lines
999 B
Python
Raw Normal View History

2020-12-17 17:16:21 +00:00
import argparse
import os
import torch
import torchvision
import utils.options as option
from scripts.ui.image_labeler.pretrained_image_patch_classifier import PretrainedImagePatchClassifier
if __name__ == "__main__":
#### options
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_imgset_structural_classifier.yml')
classifier = PretrainedImagePatchClassifier(parser.parse_args().opt)
label_to_search_for = 4
step = 1
for hq, res in classifier.get_next_sample():
res = torch.nn.functional.interpolate(res, size=hq.shape[2:], mode="nearest")
res_lbl = res[:, label_to_search_for, :, :].unsqueeze(1)
res_lbl_mask = (1.0 * (res_lbl > .5))*.5 + .5
hq = hq * res_lbl_mask
torchvision.utils.save_image(hq, os.path.join(classifier.dataset_dir, "%i.png" % (step,)))
step += 1