diff --git a/codes/trainer/eval/categorization_loss_eval.py b/codes/trainer/eval/categorization_loss_eval.py index 6f3bb6d0..37897412 100644 --- a/codes/trainer/eval/categorization_loss_eval.py +++ b/codes/trainer/eval/categorization_loss_eval.py @@ -6,6 +6,7 @@ from torchvision import transforms from tqdm import tqdm import trainer.eval.evaluator as evaluator +from data import create_dataset from models.vqvae.kmeans_mask_producer import UResnetMaskProducer from utils.util import opt_get @@ -17,14 +18,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator): assert self.batch_sz is not None normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - self.dataset = torchvision.datasets.ImageFolder( - 'E:\\4k6k\\datasets\\images\\imagenet_2017\\val', - transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ])) + self.dataset = create_dataset(opt_eval['dataset']) self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4) self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0 self.masking = opt_get(opt_eval, ['masking'], False)