1
0
forked from mrq/DL-Art-School
DL-Art-School/codes/scripts/ui/image_labeler/pretrained_image_patch_classifier.py
2020-12-18 09:18:34 -07:00

53 lines
2.0 KiB
Python

import logging
import os.path as osp
import utils
import utils.options as option
import utils.util as util
from data import create_dataset, create_dataloader
from trainer.ExtensibleTrainer import ExtensibleTrainer
class PretrainedImagePatchClassifier:
    """Run a pretrained generator over a test dataset and stream its outputs.

    Construction parses the options, creates the configured directories and
    the 'base' logger, builds the first configured dataset (with labeling
    features removed) and its dataloader, and instantiates an
    ``ExtensibleTrainer`` whose ``'generator'`` network is used for inference.

    Attributes set by ``__init__``:
        cfg: the value passed to the constructor, unchanged.
        test_loader: dataloader built over the (unwrapped) test dataset.
        model: the ``ExtensibleTrainer`` created from the parsed options.
        gen: ``model.netsG['generator']``.
        dataset_dir: ``<results_root>/<name>``, created on construction.
    """

    def __init__(self, cfg):
        """Parse ``cfg`` and build the dataloader and model.

        Args:
            cfg: configuration handed to ``option.parse`` (presumably a path
                to an options file -- confirm against ``utils.options``).

        Raises:
            IndexError: if the parsed options contain no datasets.
        """
        self.cfg = cfg

        opt = option.parse(cfg, is_train=False)
        opt = option.dict_to_nonedict(opt)
        # Publish the parsed options on the shared util module.
        utils.util.loaded_options = opt

        # Create every configured path except the experiment root and
        # anything that refers to pretrained weights or resume state.
        util.mkdirs(
            (path for key, path in opt['path'].items()
             if key != 'experiments_root'
             and 'pretrain_model' not in key
             and 'resume' not in key))
        util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'],
                          level=logging.INFO, screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))

        # Create test dataset and dataloader. Only the first configured
        # dataset is used.
        dataset_opt = list(opt['datasets'].values())[0]

        # Remove labeling features from the dataset config and wrappers.
        # A nested 'dataset' entry indicates a wrapping dataset config.
        has_inner_dataset = 'dataset' in dataset_opt
        if has_inner_dataset and 'labeler' in dataset_opt['dataset']:
            dataset_opt['dataset']['includes_labels'] = False
            del dataset_opt['dataset']['labeler']

        test_set = create_dataset(dataset_opt)
        # Unwrap so that samples come straight from the underlying dataset.
        if has_inner_dataset and hasattr(test_set, 'wrapped_dataset'):
            test_set = test_set.wrapped_dataset

        logger.info('Number of test images: {:d}'.format(len(test_set)))
        self.test_loader = create_dataloader(test_set, dataset_opt, opt)

        self.model = ExtensibleTrainer(opt)
        self.gen = self.model.netsG['generator']

        self.dataset_dir = osp.join(opt['path']['results_root'], opt['name'])
        util.mkdir(self.dataset_dir)

    def get_next_sample(self, device='cuda'):
        """Yield the generator's output for every batch of the test loader.

        Args:
            device: target passed to ``data['hq'].to(...)``. Defaults to
                ``'cuda'``, which matches the previous hard-coded behavior.

        Yields:
            Tuples ``(hq, res, data)`` where ``hq`` is ``data['hq']`` moved to
            ``device``, ``res`` is ``self.gen(hq)`` and ``data`` is the batch
            exactly as produced by the loader.
        """
        # NOTE(review): the forward pass runs without explicitly disabling
        # gradient tracking -- confirm whether the trainer already handles it.
        for data in self.test_loader:
            hq = data['hq'].to(device)
            res = self.gen(hq)
            yield hq, res, data