import logging
import os.path as osp

import utils
import utils.options as option
import utils.util as util
from data import create_dataset, create_dataloader
from trainer.ExtensibleTrainer import ExtensibleTrainer


class PretrainedImagePatchClassifier:
    """Run the 'generator' network of an ExtensibleTrainer over a test dataloader.

    Construction parses the given configuration in non-training mode, creates
    the directories listed under ``opt['path']``, configures the 'base'
    logger, builds the first dataset declared in ``opt['datasets']`` and
    instantiates the trainer whose ``netsG['generator']`` is used for
    inference.
    """

    def __init__(self, cfg):
        """Build the dataloader and model described by ``cfg``.

        Args:
            cfg: Configuration handed verbatim to ``option.parse``
                (presumably a path to an options file -- confirm against
                ``utils.options``).
        """
        self.cfg = cfg

        parsed = option.parse(cfg, is_train=False)
        opt = option.dict_to_nonedict(parsed)
        # Publish the options module-wide; other project code reads them
        # from utils.util.loaded_options.
        utils.util.loaded_options = opt

        def _wants_mkdir(key):
            # Skip the experiments root and any pretrained-model / resume
            # entries; every other path entry gets a directory.
            return not (key == 'experiments_root'
                        or 'pretrain_model' in key
                        or 'resume' in key)

        util.mkdirs(path for key, path in opt['path'].items() if _wants_mkdir(key))

        util.setup_logger(
            'base',
            opt['path']['log'],
            'test_' + opt['name'],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))

        #### Create test dataset and dataloader
        # Only the first configured dataset is used.
        all_dataset_opts = list(opt['datasets'].values())
        dataset_opt = all_dataset_opts[0]

        # Remove labeling features from the dataset config and wrappers.
        has_inner_dataset = 'dataset' in dataset_opt.keys()
        if has_inner_dataset:
            inner_opt = dataset_opt['dataset']
            if 'labeler' in inner_opt.keys():
                inner_opt['includes_labels'] = False
                del inner_opt['labeler']

        test_set = create_dataset(dataset_opt)
        # When the config nests a 'dataset', the created object may be a
        # wrapper; use the dataset it wraps if it exposes one.
        if has_inner_dataset and hasattr(test_set, 'wrapped_dataset'):
            test_set = test_set.wrapped_dataset

        image_count = len(test_set)
        logger.info('Number of test images: {:d}'.format(image_count))

        self.test_loader = create_dataloader(test_set, dataset_opt, opt)
        self.model = ExtensibleTrainer(opt)
        self.gen = self.model.netsG['generator']

        # Results are stored under <results_root>/<name>.
        self.dataset_dir = osp.join(opt['path']['results_root'], opt['name'])
        util.mkdir(self.dataset_dir)

    def get_next_sample(self):
        """Yield ``(hq, res, data)`` for every batch in the test loader.

        ``hq`` is ``data['hq']`` moved to the 'cuda' device, ``res`` is the
        generator's output for ``hq`` and ``data`` is the untouched batch
        produced by the loader.
        """
        for batch in self.test_loader:
            hq_on_device = batch['hq'].to('cuda')
            prediction = self.gen(hq_on_device)
            yield hq_on_device, prediction, batch