forked from mrq/DL-Art-School
53 lines
2.0 KiB
Python
53 lines
2.0 KiB
Python
import logging
|
|
import os.path as osp
|
|
|
|
import utils
|
|
import utils.options as option
|
|
import utils.util as util
|
|
from data import create_dataset, create_dataloader
|
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
|
|
|
|
|
class PretrainedImagePatchClassifier:
    """Run a pretrained generator, described by an options file, over a test dataset.

    Construction parses the options, prepares output directories and the
    'base' logger, builds a dataloader over the first configured dataset
    (with any labeler stripped from its config) and instantiates an
    ``ExtensibleTrainer`` whose ``'generator'`` network is exposed as
    ``self.gen``. Use ``get_next_sample`` to iterate over generator outputs.
    """

    def __init__(self, cfg):
        """Build the dataloader and model described by *cfg*.

        Args:
            cfg: passed unchanged to ``option.parse(cfg, is_train=False)``;
                presumably a path to an options file (confirm against
                ``utils.options``).

        Side effects: assigns the parsed options to
        ``utils.util.loaded_options``, creates the directories listed under
        ``opt['path']`` (except ``experiments_root`` and any pretrained-model /
        resume entries), configures the 'base' logger, and creates the
        ``<results_root>/<name>`` directory stored as ``self.dataset_dir``.
        """
        self.cfg = cfg

        opt = option.parse(cfg, is_train=False)
        opt = option.dict_to_nonedict(opt)
        # Publish the options module-wide; presumably read by other project
        # code through utils.util -- confirm.
        utils.util.loaded_options = opt

        # Create the configured output directories, skipping the experiments
        # root and any pretrained-model / resume entries.
        util.mkdirs(
            path for key, path in opt['path'].items()
            if key != 'experiments_root'
            and 'pretrain_model' not in key
            and 'resume' not in key)
        util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'],
                          level=logging.INFO, screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))

        #### Create test dataset and dataloader (first configured dataset only).
        dataset_opt = list(opt['datasets'].values())[0]
        test_set = self._create_test_set(dataset_opt)
        logger.info('Number of test images: {:d}'.format(len(test_set)))
        self.test_loader = create_dataloader(test_set, dataset_opt, opt)

        self.model = ExtensibleTrainer(opt)
        self.gen = self.model.netsG['generator']
        self.dataset_dir = osp.join(opt['path']['results_root'], opt['name'])
        util.mkdir(self.dataset_dir)

    @staticmethod
    def _create_test_set(dataset_opt):
        """Create the test dataset with labeling features removed.

        When *dataset_opt* contains a nested ``'dataset'`` config (presumably
        a wrapper dataset), its ``'labeler'`` entry is deleted and
        ``'includes_labels'`` is set to False before creation; if the created
        dataset then exposes ``wrapped_dataset``, that inner dataset is
        returned instead. Mutates *dataset_opt* in place.
        """
        has_inner_config = 'dataset' in dataset_opt
        if has_inner_config and 'labeler' in dataset_opt['dataset']:
            dataset_opt['dataset']['includes_labels'] = False
            del dataset_opt['dataset']['labeler']
        test_set = create_dataset(dataset_opt)
        # Unwrapping only applies to configs that had a nested dataset.
        if has_inner_config and hasattr(test_set, 'wrapped_dataset'):
            test_set = test_set.wrapped_dataset
        return test_set

    def get_next_sample(self, device='cuda'):
        """Yield ``(hq, res, data)`` for each batch of the test loader.

        Args:
            device: device the batch's ``'hq'`` tensor is moved to before the
                forward pass. Defaults to ``'cuda'``, the previously
                hard-coded device.

        Yields:
            hq: ``data['hq']`` moved to *device*.
            res: ``self.gen(hq)``, computed without autograd tracking since
                this class only performs inference.
            data: the raw batch produced by the dataloader.
        """
        import torch  # local import: the module header does not import torch

        for data in self.test_loader:
            hq = data['hq'].to(device)
            # Scope no_grad to the forward pass only: grad mode is thread-wide
            # state and must not stay disabled while the caller holds control
            # between yields.
            with torch.no_grad():
                res = self.gen(hq)
            yield hq, res, data
|
|
|