forked from mrq/DL-Art-School
asdf
This commit is contained in:
parent
e6c537824a
commit
6c8c8087d5
|
@ -6,6 +6,7 @@ from torchvision import transforms
|
|||
from tqdm import tqdm
|
||||
|
||||
import trainer.eval.evaluator as evaluator
|
||||
from data import create_dataset
|
||||
from models.vqvae.kmeans_mask_producer import UResnetMaskProducer
|
||||
from utils.util import opt_get
|
||||
|
||||
|
@ -17,14 +18,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
|
|||
assert self.batch_sz is not None
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
self.dataset = torchvision.datasets.ImageFolder(
|
||||
'E:\\4k6k\\datasets\\images\\imagenet_2017\\val',
|
||||
transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
]))
|
||||
self.dataset = create_dataset(opt_eval['dataset'])
|
||||
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
|
||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||
self.masking = opt_get(opt_eval, ['masking'], False)
|
||||
|
|
Loading…
Reference in New Issue
Block a user