This commit is contained in:
James Betker 2021-06-04 21:24:48 -06:00
parent e6c537824a
commit 6c8c8087d5

View File

@ -6,6 +6,7 @@ from torchvision import transforms
from tqdm import tqdm from tqdm import tqdm
import trainer.eval.evaluator as evaluator import trainer.eval.evaluator as evaluator
from data import create_dataset
from models.vqvae.kmeans_mask_producer import UResnetMaskProducer from models.vqvae.kmeans_mask_producer import UResnetMaskProducer
from utils.util import opt_get from utils.util import opt_get
@ -17,14 +18,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
assert self.batch_sz is not None assert self.batch_sz is not None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]) std=[0.229, 0.224, 0.225])
self.dataset = torchvision.datasets.ImageFolder( self.dataset = create_dataset(opt_eval['dataset'])
'E:\\4k6k\\datasets\\images\\imagenet_2017\\val',
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4) self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0 self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
self.masking = opt_get(opt_eval, ['masking'], False) self.masking = opt_get(opt_eval, ['masking'], False)